diff --git a/component/taskd/taskd/go/backend_api.go b/component/taskd/taskd/go/backend_api.go
index 54d07f333afa01711f80a0cb00bf40b493877393..e90585f99861fd1070494d1d142e8b3969156129 100644
--- a/component/taskd/taskd/go/backend_api.go
+++ b/component/taskd/taskd/go/backend_api.go
@@ -113,7 +113,7 @@ func StepOut() C.int {
 //
 //export InitTaskdManager
 func InitTaskdManager(configStr *C.char) C.int {
-	var config manager.Config
+	var config constant.Config
 	if err := json.Unmarshal([]byte(C.GoString(configStr)), &config); err != nil {
 		return C.int(1)
 	}
diff --git a/component/taskd/taskd/go/common/constant/const.go b/component/taskd/taskd/go/common/constant/const.go
index 2aa72e3df8c94738f7a0ee9233a96b453676e908..e713bbc1f140c3b7f0245382157d895839a364f3 100644
--- a/component/taskd/taskd/go/common/constant/const.go
+++ b/component/taskd/taskd/go/common/constant/const.go
@@ -131,6 +131,9 @@ const (
 	ProfilingAllOnCmdCode = 711
 )
 
+// DefaultBizType is net tool default biz type
+const DefaultBizType = "default"
+
 // RequestChanNum message handler chan number
 const RequestChanNum = 100
 
diff --git a/component/taskd/taskd/go/common/constant/type.go b/component/taskd/taskd/go/common/constant/type.go
index 48342e5e25c0bb273b1dc9006da39ad74a669fa1..1882d5339929eba3eea6510a54537e5036b203ea 100644
--- a/component/taskd/taskd/go/common/constant/type.go
+++ b/component/taskd/taskd/go/common/constant/type.go
@@ -45,3 +45,29 @@ type ProfilingSwitch struct {
 type ProfilingWorkerState struct {
 	state string
 }
+
+// ClusterInfo define the information from the cluster
+type ClusterInfo struct {
+	// Ip indicate cluster server ip
+	Ip string `json:"ip"`
+	// Port indicate cluster server port
+	Port string `json:"port"`
+	// Name indicate cluster server service name
+	Name string `json:"name"`
+	// Role indicate cluster server role, e.g. clusterd, taskd
+	Role string `json:"role"`
+}
+
+// Config define the configuration of manager
+type Config struct {
+	// JobId indicate the id of the job where the manager is located
+	JobId string `json:"job_id"`
+	// NodeNums indicate the number of nodes where the manager is located
+	NodeNums int `json:"node_nums"`
+	// ProcPerNode indicate the number of business processes where the manager's job is located
+	ProcPerNode int `json:"proc_per_node"`
+	// PluginDir indicate the plugin dir
+	PluginDir string `json:"plugin_dir"`
+	// ClusterInfos indicate the information of cluster
+	ClusterInfos []ClusterInfo `json:"cluster_infos"`
+}
diff --git a/component/taskd/taskd/go/common/utils/utils.go b/component/taskd/taskd/go/common/utils/utils.go
index 3325fd55951bc8e5b1256ce4be10c0a028b1a12e..c43b00e75a3e2ba6a5d412b00f59ca5a2aa359e9 100644
--- a/component/taskd/taskd/go/common/utils/utils.go
+++ b/component/taskd/taskd/go/common/utils/utils.go
@@ -255,7 +255,6 @@ func GetClusterdAddr() (string, error) {
 	parsedIP := net.ParseIP(ipFromEnv)
 	if parsedIP == nil {
 		return "", fmt.Errorf("%s is NOT a valid IP address", ipFromEnv)
-
 	}
 	return ipFromEnv + constant.ClusterdPort, nil
 }
diff --git a/component/taskd/taskd/go/framework_backend/manager/application/businessStream.go b/component/taskd/taskd/go/framework_backend/manager/application/businessStream.go
index 2f8d80987b80b34fa59042661904be41d049b31c..68625c9fe239716ed48dca91956474962c02a826 100644
--- a/component/taskd/taskd/go/framework_backend/manager/application/businessStream.go
+++ b/component/taskd/taskd/go/framework_backend/manager/application/businessStream.go
@@ -20,13 +20,12 @@ import (
 	"fmt"
 	"strings"
 
-	"github.com/google/uuid"
-
 	"ascend-common/common-utils/hwlog"
"taskd/common/constant" "taskd/framework_backend/manager/infrastructure" "taskd/framework_backend/manager/infrastructure/storage" "taskd/framework_backend/manager/service" + "taskd/framework_backend/manager/service/adaptor" "taskd/toolkit_backend/net/common" ) @@ -141,8 +140,7 @@ func (b *BusinessStreamProcessor) DistributeMsg(msgs []infrastructure.Msg) error continue } for _, receiver := range msg.Receiver { - if receiver == common.MgrRole { - b.DistributedMsgToMgr(msg) + if adaptor.DistributeMsg(msg, receiver) { continue } sendMsg, err := json.Marshal(msg.Body) @@ -156,17 +154,6 @@ func (b *BusinessStreamProcessor) DistributeMsg(msgs []infrastructure.Msg) error return nil } -// DistributedMsgToMgr distributed message to manager -func (b *BusinessStreamProcessor) DistributedMsgToMgr(msg infrastructure.Msg) { - b.MsgHandler.SendMsgToMgr(uuid.New().String(), constant.DefaultDomainName, - &common.Position{ - Role: common.MgrRole, - ServerRank: "0", - ProcessRank: "-1", - }, msg.Body) - hwlog.RunLog.Debugf("business handler send msg %v to mgr", msg.Body) -} - // DistributedMsgToOthers distributed message to others func (b *BusinessStreamProcessor) DistributedMsgToOthers(receiver string, sendMsg []byte) { var dst *common.Position @@ -185,6 +172,6 @@ func (b *BusinessStreamProcessor) DistributedMsgToOthers(receiver string, sendMs return } } - b.MsgHandler.SendMsgUseGrpc(constant.DefaultDomainName, string(sendMsg), dst) + b.MsgHandler.SendMsgUseGrpc(constant.DefaultBizType, string(sendMsg), dst) hwlog.RunLog.Debugf("business handler send msg %s to others", string(sendMsg)) } diff --git a/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go b/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go index 18d4c1611ba77ddea23c2735440763f0422812dc..7ada603ebd80e6f0ca0f3626fa6d051b797130a4 100644 --- a/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go +++ b/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go @@ -35,7 +35,6 @@ import ( type MsgHandlerInterface interface { GetDataPool() *storage.DataPool SendMsgUseGrpc(msgType string, msgBody string, dst *common.Position) - SendMsgToMgr(uuid string, bizType string, src *common.Position, msgBody storage.MsgBody) } // MsgHandler receive, send, process and store message info @@ -176,16 +175,6 @@ func (mhd *MsgHandler) SendMsgUseGrpc(msgType string, msgBody string, dst *commo } } -// SendMsgToMgr send message into manager message queue -func (mhd *MsgHandler) SendMsgToMgr(uuid string, bizType string, src *common.Position, msgBody storage.MsgBody) { - data := mhd.MsgQueue.NewMsg(uuid, bizType, src, msgBody) - err := mhd.MsgQueue.Enqueue(data) - if err != nil { - hwlog.RunLog.Errorf("enqueue failed: %v", err) - mhd.SendMsgUseGrpc(bizType, err.Error(), src) - } -} - // GetDataPool return data pool func (mhd *MsgHandler) GetDataPool() *storage.DataPool { return mhd.DataPool diff --git a/component/taskd/taskd/go/framework_backend/manager/application/msghandler_test.go b/component/taskd/taskd/go/framework_backend/manager/application/msghandler_test.go index d1addc0bfb89f77babc229f2886fa35dd91f4b6e..757d8d3ad0d60abeba46bbb5822485286a2ea56f 100644 --- a/component/taskd/taskd/go/framework_backend/manager/application/msghandler_test.go +++ b/component/taskd/taskd/go/framework_backend/manager/application/msghandler_test.go @@ -27,7 +27,6 @@ import ( "github.com/smartystreets/goconvey/convey" "ascend-common/common-utils/hwlog" - "taskd/common/constant" 
"taskd/framework_backend/manager/infrastructure/storage" "taskd/framework_backend/manager/service" "taskd/toolkit_backend/net" @@ -137,24 +136,3 @@ func TestSendMsgUseGrpc(t *testing.T) { convey.So(req.Dst, convey.ShouldEqual, testDst) }) } - -// TestSendMsgToMgr test manager send msg enqueue -func TestSendMsgToMgr(t *testing.T) { - convey.Convey("TestSendMsgToMgr manager send msg enqueue success", t, func() { - mhd := NewMsgHandler() - testSrc := &common.Position{Role: common.WorkerRole} - oldLength := len(mhd.MsgQueue.Queue) - mhd.SendMsgToMgr("test-uuid", "test-type", testSrc, storage.MsgBody{}) - convey.So(oldLength+1, convey.ShouldEqual, len(mhd.MsgQueue.Queue)) - }) - convey.Convey("TestSendMsgToMgr manager send msg enqueue fail", t, func() { - mhd := &MsgHandler{ - Sender: &service.MsgSender{RequestChan: make(chan service.SendGrpcMsg, constant.RequestChanNum)}, - MsgQueue: &storage.MsgQueue{Queue: make([]storage.BaseMessage, constant.MaxMsgQueueLength), - Mutex: sync.Mutex{}}, - } - testSrc := &common.Position{Role: common.WorkerRole} - mhd.SendMsgToMgr("test-uuid", "test-type", testSrc, storage.MsgBody{}) - convey.So(len(mhd.MsgQueue.Queue), convey.ShouldEqual, constant.MaxMsgQueueLength) - }) -} diff --git a/component/taskd/taskd/go/framework_backend/manager/manager.go b/component/taskd/taskd/go/framework_backend/manager/manager.go index 1992e91b9d7a2dd6b068f7fd81f6b8544b89d2e9..50ada6cca7b5541915693872dc4d992e94386501 100644 --- a/component/taskd/taskd/go/framework_backend/manager/manager.go +++ b/component/taskd/taskd/go/framework_backend/manager/manager.go @@ -18,52 +18,18 @@ package manager import ( "context" "fmt" - "io" - "sync/atomic" "time" - "github.com/google/uuid" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - "ascend-common/common-utils/hwlog" - "clusterd/pkg/interface/grpc/profiling" - "clusterd/pkg/interface/grpc/recover" "taskd/common/constant" "taskd/common/utils" "taskd/framework_backend/manager/application" "taskd/framework_backend/manager/infrastructure/storage" - "taskd/toolkit_backend/net/common" + "taskd/framework_backend/manager/service/adaptor" ) -// ClusterInfo define the information from the cluster -type ClusterInfo struct { - // IP indicate cluster server ip - Ip string `json:"ip"` - // Port indicate cluster server port - Port string `json:"port"` - // Name indicate cluster server service name - Name string `json:"name"` - // Role - Role string `json:"role"` -} - -// Config define the configuration of manager -type Config struct { - // JobId indicate the id of the job where the manager is located - JobId string `json:"job_id"` - // NodeNums indicate the number of nodes where the manager is located - NodeNums int `json:"node_nums"` - // ProcPerNode indicate the number of business processes where the manager's job is located - ProcPerNode int `json:"proc_per_node"` - // PluginDir indicate the plugin dir - PluginDir string `json:"plugin_dir"` - // ClusterInfos indicate the information of cluster - ClusterInfos []ClusterInfo `json:"cluster_infos"` -} - // NewTaskDManager return taskd manager instance -func NewTaskDManager(config Config) *BaseManager { +func NewTaskDManager(config constant.Config) *BaseManager { return &BaseManager{ Config: config, } @@ -71,21 +37,13 @@ func NewTaskDManager(config Config) *BaseManager { // BaseManager the class taskd manager backend type BaseManager struct { - Config - BusinessHandler *application.BusinessStreamProcessor - MsgHd *application.MsgHandler - svcCtx context.Context - cancelFunc 
context.CancelFunc - profilingFromClusterD atomic.Bool + constant.Config + BusinessHandler *application.BusinessStreamProcessor + MsgHd *application.MsgHandler + svcCtx context.Context + cancelFunc context.CancelFunc } -const ( - roleTaskd = "taskd" - maxRegRetryTime = 60 - maxWaitTime = 60 - waitGapTime = 1 -) - // Init base manger func (m *BaseManager) Init() error { if err := utils.InitHwLogger("manager.log", context.Background()); err != nil { @@ -96,15 +54,12 @@ func (m *BaseManager) Init() error { m.svcCtx, m.cancelFunc = context.WithCancel(context.Background()) m.MsgHd = application.NewMsgHandler() m.MsgHd.Start(m.svcCtx) - m.BusinessHandler = application.NewBusinessStreamProcessor(m.MsgHd) if err := m.BusinessHandler.Init(); err != nil { hwlog.RunLog.Errorf("business handler init failed, err: %v", err) return err } - go m.registerClusterD(0) - go m.watchProfilingCmdChange() - + m.clusterHandle() hwlog.RunLog.Info("manager init success!") return nil } @@ -147,214 +102,33 @@ func (m *BaseManager) Service(snapshot *storage.SnapShot) error { return nil } -func (m *BaseManager) registerClusterD(retryTime time.Duration) { - if retryTime >= maxRegRetryTime { - hwlog.RunLog.Error("init clusterd connect meet max retry time") - return - } - time.Sleep(retryTime * time.Second) - addr, err := utils.GetClusterdAddr() - if err != nil { - hwlog.RunLog.Errorf("get clusterd address err: %v", err) - return - } - hwlog.RunLog.Infof("get clusterd addr %v", addr) - conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) - if err != nil { - hwlog.RunLog.Errorf("init clusterd connect err: %v", err) - m.registerClusterD(retryTime + 1) - return - } - - go m.subscribeProfiling(conn, 0) - go m.subscribeSwitchNic(conn) -} - -func (m *BaseManager) subscribeSwitchNic(conn *grpc.ClientConn) { - client := pb.NewRecoverClient(conn) - clientInfo := &pb.ClientInfo{ - JobId: m.JobId, - Role: roleTaskd, - } - for { - exit, wTime := m.listenSignal(client, clientInfo, waitGapTime) - if exit { - hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info") - break +func (m *BaseManager) clusterHandle() { + foundClusterd := false + foundTaskd := false + for _, info := range m.ClusterInfos { + m.clusterAdaptor(info) + if info.Role == constant.ClusterDRank { + foundClusterd = true } - time.Sleep(time.Duration(wTime) * time.Second) - if wTime > maxWaitTime { - wTime = 1 + if info.Role == constant.TaskDRank { + foundTaskd = true } } -} - -func (m *BaseManager) listenSignal(client pb.RecoverClient, clientInfo *pb.ClientInfo, wTime int) (bool, int) { - stream, err := client.SubscribeNotifySwitch(m.svcCtx, clientInfo) - if err != nil { - hwlog.RunLog.Errorf("register Clusterd notify switch fail, err: %v", err) - return false, wTime + waitGapTime + if !foundClusterd { + m.clusterAdaptor(constant.ClusterInfo{Role: constant.ClusterDRank}) } - for { - select { - case <-m.svcCtx.Done(): - hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info") - return true, 0 - case <-stream.Context().Done(): - hwlog.RunLog.Error("server stream abnormal interruption, register again") - return false, wTime + waitGapTime - default: - responseMsg, recvErr := stream.Recv() - if recvErr == io.EOF { - hwlog.RunLog.Info("stream EOF, register again") - return false, waitGapTime - } - if recvErr != nil { - hwlog.RunLog.Error(recvErr) - continue - } - hwlog.RunLog.Infof("receive switch nic info: %v", responseMsg) - globalOps := responseMsg.GetOp() - globalRanks := responseMsg.GetRankID() - 
m.enqueueSwitchNic(globalRanks, globalOps) - } + if !foundTaskd { + m.clusterAdaptor(constant.ClusterInfo{Role: constant.TaskDRank}) } } -func (m *BaseManager) enqueueSwitchNic(ranks []string, ops []bool) { - rankStr := utils.ObjToString(ranks) - opStr := utils.ObjToString(ops) - msg := map[string]string{ - constant.GlobalRankKey: rankStr, - constant.GlobalOpKey: opStr, - constant.SwitchJobID: m.JobId, - } - message := storage.BaseMessage{ - Header: storage.MsgHeader{ - BizType: "default", - Uuid: uuid.New().String(), - Src: &common.Position{ - Role: constant.ClusterRole, - ServerRank: constant.ClusterDRank, - }, - Timestamp: time.Now(), - }, - Body: storage.MsgBody{ - MsgType: constant.Action, - Code: constant.SwitchNicCode, - Extension: msg, - }, - } - err := m.MsgHd.MsgQueue.Enqueue(message) +func (m *BaseManager) clusterAdaptor(info constant.ClusterInfo) { + clusterAdaptor, err := adaptor.InitAdaptor(m.JobId, m.svcCtx, m.MsgHd.MsgQueue, info) if err != nil { - hwlog.RunLog.Errorf("enqueue switch msg err %v", err) + hwlog.RunLog.Error(err) return } - hwlog.RunLog.Infof("enqueue switch msg %v", msg) -} - -func (m *BaseManager) subscribeProfiling(conn *grpc.ClientConn, retryTime time.Duration) { - m.profilingFromClusterD.Store(false) - if retryTime >= maxRegRetryTime { - hwlog.RunLog.Error("register Cluster profiling meet max retry time") - return - } - time.Sleep(retryTime * time.Second) - traceClient := profiling.NewTrainingDataTraceClient(conn) - stream, err := traceClient.SubscribeDataTraceSwitch(m.svcCtx, &profiling.ProfilingClientInfo{ - JobId: m.JobId, - Role: roleTaskd, - }) - if err != nil { - hwlog.RunLog.Errorf("register Cluster profiling fail, err: %v", err) - go m.subscribeProfiling(conn, retryTime+1) - return - } - m.profilingFromClusterD.Store(true) - for { - select { - case <-m.svcCtx.Done(): - hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info") - return - case <-stream.Context().Done(): - hwlog.RunLog.Info("client stream exit, stop subscribe profiling info and re-register") - go m.subscribeProfiling(conn, retryTime+1) - return - default: - responseMsg, recvErr := stream.Recv() - if recvErr != nil { - hwlog.RunLog.Error(recvErr) - } else { - hwlog.RunLog.Infof("receive profiling info: %v", responseMsg) - profilingMsg := responseMsg.GetProfilingSwitch() - // notify framework receive profiling msg - domainSwitch := utils.PfSwitchToPfDomainSwitch(convertProfilingMsg(profilingMsg)) - m.enqueueProfilingSwitch(domainSwitch, constant.ClusterDRank) - } - } - } -} - -func (m *BaseManager) enqueueProfilingSwitch(cmd constant.ProfilingDomainCmd, whichServer string) { - message := storage.BaseMessage{ - Header: storage.MsgHeader{ - BizType: "default", - Uuid: uuid.New().String(), - Src: &common.Position{ - Role: constant.ClusterRole, - ServerRank: whichServer, - }, - Timestamp: time.Now(), - }, - Body: storage.MsgBody{ - MsgType: constant.Action, - Code: utils.ProfilingCmdToBizCode(cmd), - }, - } - err := m.MsgHd.MsgQueue.Enqueue(message) - if err != nil { - hwlog.RunLog.Infof("%s enqueue profiling cmd %v err %v", whichServer, cmd, err) - return - } - hwlog.RunLog.Infof("%s enqueue profiling cmd %v", whichServer, cmd) -} - -func (m *BaseManager) watchProfilingCmdChange() { - hwlog.RunLog.Info("begin watch ProfilingSwitchFilePath...") - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-m.svcCtx.Done(): - hwlog.RunLog.Info("end watch ProfilingSwitchFilePath...") - return - case <-ticker.C: - if m.profilingFromClusterD.Load() { - 
hwlog.RunLog.Infof("manager register clusterd, donot watch profiling file.") - return - } - m.getProfilingFromFile() - } - } -} - -func (m *BaseManager) getProfilingFromFile() { - profilingSwitch, err := utils.GetProfilingSwitch(constant.ProfilingSwitchFilePath) - if err != nil { - hwlog.RunLog.Errorf("GetProfilingSwitch err: %v", err) - return - } - domainSwitch := utils.PfSwitchToPfDomainSwitch(profilingSwitch) - m.enqueueProfilingSwitch(domainSwitch, constant.TaskDRank) -} - -func convertProfilingMsg(profilingSwitchData *profiling.ProfilingSwitch) constant.ProfilingSwitch { - profilingSwitch := constant.ProfilingSwitch{ - CommunicationOperator: profilingSwitchData.CommunicationOperator, - Step: profilingSwitchData.Step, - SaveCheckpoint: profilingSwitchData.SaveCheckpoint, - FP: profilingSwitchData.FP, - DataLoader: profilingSwitchData.DataLoader, + if err = clusterAdaptor.Handle(); err != nil { + hwlog.RunLog.Error(err) } - return profilingSwitch } diff --git a/component/taskd/taskd/go/framework_backend/manager/manager_test.go b/component/taskd/taskd/go/framework_backend/manager/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1ebf167ee08c6e0bf16421527405aed0e66d7299 --- /dev/null +++ b/component/taskd/taskd/go/framework_backend/manager/manager_test.go @@ -0,0 +1,168 @@ +/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/
+
+// Package manager for taskd manager backend
+package manager
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"sync/atomic"
+	"testing"
+
+	"github.com/agiledragon/gomonkey/v2"
+	"github.com/smartystreets/goconvey/convey"
+
+	"ascend-common/common-utils/hwlog"
+	"taskd/common/constant"
+	"taskd/common/utils"
+	"taskd/framework_backend/manager/application"
+	"taskd/framework_backend/manager/infrastructure/storage"
+	"taskd/framework_backend/manager/service/adaptor"
+)
+
+const (
+	JobId = "JobId"
+)
+
+// TestMain test main
+func TestMain(m *testing.M) {
+	if err := setup(); err != nil {
+		return
+	}
+	code := m.Run()
+	fmt.Printf("exit_code = %v\n", code)
+	os.Exit(code)
+}
+
+func setup() error {
+	return initLog()
+}
+
+func initLog() error {
+	logConfig := &hwlog.LogConfig{
+		OnlyToStdout: true,
+	}
+	if err := hwlog.InitRunLogger(logConfig, context.Background()); err != nil {
+		fmt.Printf("init hwlog failed, %v\n", err)
+		return err
+	}
+	return nil
+}
+
+func getManager() *BaseManager {
+	config := constant.Config{
+		JobId: JobId,
+		ClusterInfos: []constant.ClusterInfo{
+			{
+				Role: constant.ClusterDRank,
+			},
+		},
+	}
+	manager := NewTaskDManager(config)
+	manager.MsgHd = application.NewMsgHandler()
+	manager.BusinessHandler = application.NewBusinessStreamProcessor(manager.MsgHd)
+	return manager
+}
+
+func TestNewTaskDManager(t *testing.T) {
+	convey.Convey("TestNewTaskDManager", t, func() {
+		manager := getManager()
+		convey.So(manager.JobId, convey.ShouldEqual, JobId)
+	})
+}
+
+func TestBaseManagerInit(t *testing.T) {
+	convey.Convey("TestBaseManagerInit", t, func() {
+		patches := gomonkey.NewPatches()
+		defer patches.Reset()
+		patches.ApplyFunc(utils.InitHwLogger, func(string, context.Context) error {
+			return nil
+		})
+		msgHd := application.NewMsgHandler()
+		patches.ApplyFunc(application.NewMsgHandler, func() *application.MsgHandler {
+			return msgHd
+		})
+		calledStart := atomic.Bool{}
+		patches.ApplyMethod(msgHd, "Start", func(*application.MsgHandler, context.Context) {
+			calledStart.Store(true)
+		})
+		calledClusterHandle := atomic.Bool{}
+		manager := getManager()
+		patches.ApplyPrivateMethod(manager, "clusterHandle", func(*BaseManager) {
+			calledClusterHandle.Store(true)
+		})
+		err := manager.Init()
+		convey.So(err, convey.ShouldBeNil)
+		convey.So(calledClusterHandle.Load(), convey.ShouldBeTrue)
+		convey.So(calledStart.Load(), convey.ShouldBeTrue)
+	})
+}
+
+func TestBaseManagerStart(t *testing.T) {
+	convey.Convey("TestBaseManagerStart", t, func() {
+		patches := gomonkey.NewPatches()
+		defer patches.Reset()
+		manager := getManager()
+		calledInit := false
+		calledProcess := false
+		patches.ApplyMethod(manager, "Init", func(*BaseManager) error {
+			calledInit = true
+			return nil
+		}).ApplyMethod(manager, "Process", func(*BaseManager) error {
+			calledProcess = true
+			return nil
+		})
+		err := manager.Start()
+		convey.So(err, convey.ShouldBeNil)
+		convey.So(calledInit, convey.ShouldBeTrue)
+		convey.So(calledProcess, convey.ShouldBeTrue)
+	})
+}
+
+func TestBaseManagerService(t *testing.T) {
+	convey.Convey("TestBaseManagerService", t, func() {
+		manager := getManager()
+		patches := gomonkey.NewPatches()
+		defer patches.Reset()
+		calledAllocateToken := false
+		calledStreamRun := false
+		patches.ApplyMethod(manager.BusinessHandler, "AllocateToken",
+			func(*application.BusinessStreamProcessor, *storage.SnapShot) {
+				calledAllocateToken = true
+			}).ApplyMethod(manager.BusinessHandler, "StreamRun",
+			func(*application.BusinessStreamProcessor) error {
+				calledStreamRun = true
+				return nil
+			})
+		err := manager.Service(&storage.SnapShot{})
+		convey.So(err, convey.ShouldBeNil)
+		convey.So(calledAllocateToken, convey.ShouldBeTrue)
+		convey.So(calledStreamRun, convey.ShouldBeTrue)
+	})
+}
+
+type mockHandler struct{}
+
+func (m *mockHandler) Handle() error { return nil }
+
+func (m *mockHandler) SendMsg(msg storage.MsgBody) error { return nil }
+
+func TestBaseManagerClusterHandle(t *testing.T) {
+	convey.Convey("TestBaseManagerClusterHandle", t, func() {
+		manager := getManager()
+		patches := gomonkey.NewPatches()
+		defer patches.Reset()
+		calledHandle := false
+		m := &mockHandler{}
+		patches.ApplyFunc(adaptor.InitAdaptor, func(
+			string, context.Context, *storage.MsgQueue, constant.ClusterInfo) (adaptor.Adaptor, error) {
+			return m, nil
+		}).ApplyMethod(m, "Handle", func(*mockHandler) error {
+			calledHandle = true
+			return nil
+		})
+		manager.clusterHandle()
+		convey.So(calledHandle, convey.ShouldBeTrue)
+	})
+}
diff --git a/component/taskd/taskd/go/framework_backend/manager/service/adaptor/clusterd.go b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/clusterd.go
new file mode 100644
index 0000000000000000000000000000000000000000..3c0161fee550b69154bc039efe00d2ea3c710545
--- /dev/null
+++ b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/clusterd.go
@@ -0,0 +1,352 @@
+/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+// Package adaptor provides cluster adaptors, e.g. clusterd, taskd
+package adaptor
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"sync/atomic"
+	"time"
+
+	"github.com/google/uuid"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+
+	"ascend-common/common-utils/hwlog"
+	"clusterd/pkg/interface/grpc/profiling"
+	pb "clusterd/pkg/interface/grpc/recover"
+	"taskd/common/constant"
+	"taskd/common/utils"
+	"taskd/framework_backend/manager/infrastructure/storage"
+	"taskd/toolkit_backend/net/common"
+)
+
+const (
+	roleTaskd        = "taskd"
+	maxRegRetryTime  = 60
+	maxSendRetryTime = 3
+	maxWaitTime      = 60
+	waitGapTime      = 1
+)
+
+func newClusterdAdaptor(
+	jobId string, svcCtx context.Context, queue *storage.MsgQueue, info constant.ClusterInfo) *clusterdAdaptor {
+	return &clusterdAdaptor{
+		baseAdaptor: &baseAdaptor{
+			jobId:       jobId,
+			svcCtx:      svcCtx,
+			queue:       queue,
+			clusterInfo: info,
+		},
+		profilingFromClusterD: atomic.Bool{},
+	}
+}
+
+// Handle connects to clusterd and starts the subscribe goroutines
+func (m *clusterdAdaptor) Handle() error {
+	addr, err := m.getAddr()
+	if err != nil {
+		return fmt.Errorf("get clusterd addr err %v", err)
+	}
+	hwlog.RunLog.Infof("get clusterd addr %v", addr)
+	go m.registerClusterd(addr)
+	go m.watchProfilingCmdChange()
+	return nil
+}
+
+// SendMsg sends message to clusterd
+func (m *clusterdAdaptor) SendMsg(msg storage.MsgBody) error {
+	if msg.MsgType == constant.SwitchNic {
+		return m.sendSwitchNicStatusRetry(msg)
+	}
+	return fmt.Errorf("cannot handle msg %v", msg)
+}
+
+func (m *clusterdAdaptor) sendSwitchNicStatusRetry(msg storage.MsgBody) error {
+	var status bool
+	switch msg.Code {
+	case 0:
+		status = false
+	case 1:
+		status = true
+	default:
+		return fmt.Errorf("invalid status %d of SwitchNic", msg.Code)
+	}
+	sendSucc := false
+	var err error
+	var addr string
+	var conn *grpc.ClientConn
+	for retryTime := time.Duration(0); retryTime < maxSendRetryTime; retryTime++ {
+		time.Sleep(retryTime * time.Second)
+		addr, err = m.getAddr()
+		if err != nil {
+			return fmt.Errorf("send switch nic status err: %v", err)
+		}
+		conn, err = grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
+		if err != nil {
+			hwlog.RunLog.Errorf("init clusterd connect err: %v", err)
+			continue
+		}
+		client := pb.NewRecoverClient(conn)
+		_, err = client.ReplySwitchNicResult(context.TODO(), &pb.SwitchResult{Result: status, JobId: m.jobId})
+		if err != nil {
+			hwlog.RunLog.Errorf("reply SwitchNicResult err: %v", err)
+			continue
+		}
+		sendSucc = true
+		break
+	}
+	if !sendSucc {
+		return fmt.Errorf("reply switchNic result failed, last err: %v", err)
+	}
+	return nil
+}
+
+func (m *clusterdAdaptor) registerClusterd(addr string) {
+	var conn *grpc.ClientConn
+	var err error
+	var succ bool
+	for retryTime := time.Duration(0); retryTime < maxRegRetryTime; retryTime++ {
+		time.Sleep(retryTime * time.Second)
+		conn, err = grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
+		if err != nil {
+			hwlog.RunLog.Errorf("init clusterd connect err: %v", err)
+			continue
+		}
+		succ = true
+		break
+	}
+	if !succ {
+		hwlog.RunLog.Error("init clusterd connect meet max retry time")
+		return
+	}
+	go m.subscribeProfiling(conn)
+	go m.subscribeSwitchNic(conn)
+}
+
+func (m *clusterdAdaptor) subscribeSwitchNic(conn *grpc.ClientConn) {
+	client := pb.NewRecoverClient(conn)
+	clientInfo := &pb.ClientInfo{
+		JobId: m.jobId,
+		Role:  roleTaskd,
+	}
+	wTime := waitGapTime
+	for {
+		exit, nextWait := m.listenSignal(client, clientInfo, wTime)
+		if exit {
+			hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info")
+			break
+		}
+		wTime = nextWait
+		if wTime > maxWaitTime {
+			wTime = 1
+		}
+		time.Sleep(time.Duration(wTime) * time.Second)
+	}
+}
+
+func (m *clusterdAdaptor) listenSignal(client pb.RecoverClient, clientInfo *pb.ClientInfo, wTime int) (bool, int) {
+	stream, err := client.SubscribeNotifySwitch(m.svcCtx, clientInfo)
+	if err != nil {
+		hwlog.RunLog.Errorf("register Clusterd notify switch fail, err: %v", err)
+		return false, wTime + waitGapTime
+	}
+	for {
+		select {
+		case <-m.svcCtx.Done():
+			hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info")
+			return true, 0
+		case <-stream.Context().Done():
+			hwlog.RunLog.Error("server stream abnormal interruption, register again")
+			return false, wTime + waitGapTime
+		default:
+			responseMsg, recvErr := stream.Recv()
+			if recvErr == io.EOF {
+				hwlog.RunLog.Info("stream EOF, register again")
+				return false, waitGapTime
+			}
+			if recvErr != nil {
+				hwlog.RunLog.Error(recvErr)
+				continue
+			}
+			hwlog.RunLog.Infof("receive switch nic info: %v", responseMsg)
+			globalOps := responseMsg.GetOp()
+			globalRanks := responseMsg.GetRankID()
+			m.enqueueSwitchNic(globalRanks, globalOps)
+		}
+	}
+}
+
+func (m *clusterdAdaptor) enqueueSwitchNic(ranks []string, ops []bool) {
+	rankStr := utils.ObjToString(ranks)
+	opStr := utils.ObjToString(ops)
+	msg := map[string]string{
+		constant.GlobalRankKey: rankStr,
+		constant.GlobalOpKey:   opStr,
+		constant.SwitchJobID:   m.jobId,
+	}
+	message := storage.BaseMessage{
+		Header: storage.MsgHeader{
+			BizType: constant.DefaultBizType,
+			Uuid:    uuid.New().String(),
+			Src: &common.Position{
+				Role:       constant.ClusterRole,
+				ServerRank: constant.ClusterDRank,
+			},
+			Timestamp: time.Now(),
+		},
+		Body: storage.MsgBody{
+			MsgType:   constant.Action,
+			Code:      constant.SwitchNicCode,
+			Extension: msg,
+		},
+	}
+	err := m.queue.Enqueue(message)
+	if err != nil {
+		hwlog.RunLog.Errorf("enqueue switch msg err %v", err)
+		return
+	}
+	hwlog.RunLog.Infof("enqueue switch msg %v", msg)
+}
+
+func (m *clusterdAdaptor) subscribeProfiling(conn *grpc.ClientConn) {
+	var stream profiling.TrainingDataTrace_SubscribeDataTraceSwitchClient
+	var err error
+	var succ bool
+	for retryTime := time.Duration(0); retryTime < maxRegRetryTime; retryTime++ {
+		m.profilingFromClusterD.Store(false)
+		time.Sleep(retryTime * time.Second)
+		traceClient := profiling.NewTrainingDataTraceClient(conn)
+		stream, err = traceClient.SubscribeDataTraceSwitch(m.svcCtx, &profiling.ProfilingClientInfo{
+			JobId: m.jobId,
+			Role:  roleTaskd,
+		})
+		if err != nil {
+			hwlog.RunLog.Errorf("register Cluster profiling fail, err: %v", err)
+			continue
+		}
+		if !m.recvProfiling(stream) {
+			succ = true
+			break
+		}
+	}
+	if !succ {
+		hwlog.RunLog.Error("register Cluster profiling meet max retry time")
+	}
+}
+
+func (m *clusterdAdaptor) recvProfiling(stream profiling.TrainingDataTrace_SubscribeDataTraceSwitchClient) bool {
+	m.profilingFromClusterD.Store(true)
+	for {
+		select {
+		case <-m.svcCtx.Done():
+			hwlog.RunLog.Info("taskd exit, stop subscribe clusterd fault info")
+			return false
+		case <-stream.Context().Done():
+			hwlog.RunLog.Info("client stream exit, stop subscribe profiling info and re-register")
+			return true
+		default:
+			responseMsg, recvErr := stream.Recv()
+			if recvErr != nil {
+				hwlog.RunLog.Error(recvErr)
+			} else {
+				hwlog.RunLog.Infof("receive profiling info: %v", responseMsg)
+				profilingMsg := responseMsg.GetProfilingSwitch()
+				// notify framework receive profiling msg
+				domainSwitch := utils.PfSwitchToPfDomainSwitch(convertProfilingMsg(profilingMsg))
+				m.enqueueProfilingSwitch(domainSwitch, constant.ClusterDRank)
+			}
+		}
+	}
+}
+
+func (m *clusterdAdaptor) enqueueProfilingSwitch(cmd constant.ProfilingDomainCmd, whichServer string) {
+	message := storage.BaseMessage{
+		Header: storage.MsgHeader{
+			BizType: constant.DefaultBizType,
+			Uuid:    uuid.New().String(),
+			Src: &common.Position{
+				Role:       constant.ClusterRole,
+				ServerRank: whichServer,
+			},
+			Timestamp: time.Now(),
+		},
+		Body: storage.MsgBody{
+			MsgType: constant.Action,
+			Code:    utils.ProfilingCmdToBizCode(cmd),
+		},
+	}
+	err := m.queue.Enqueue(message)
+	if err != nil {
+		hwlog.RunLog.Errorf("%s enqueue profiling cmd %v err %v", whichServer, cmd, err)
+		return
+	}
+	hwlog.RunLog.Infof("%s enqueue profiling cmd %v", whichServer, cmd)
+}
+
+func (m *clusterdAdaptor) watchProfilingCmdChange() {
+	hwlog.RunLog.Info("begin watch ProfilingSwitchFilePath...")
+	ticker := time.NewTicker(time.Second)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-m.svcCtx.Done():
+			hwlog.RunLog.Info("end watch ProfilingSwitchFilePath...")
+			return
+		case <-ticker.C:
+			if m.profilingFromClusterD.Load() {
+				hwlog.RunLog.Info("manager registered clusterd, do not watch profiling file.")
+				return
+			}
+			m.getProfilingFromFile()
+		}
+	}
+}
+
+func (m *clusterdAdaptor) getProfilingFromFile() {
+	profilingSwitch, err := utils.GetProfilingSwitch(constant.ProfilingSwitchFilePath)
+	if err != nil {
+		hwlog.RunLog.Errorf("GetProfilingSwitch err: %v", err)
+		return
+	}
+	domainSwitch := utils.PfSwitchToPfDomainSwitch(profilingSwitch)
+	m.enqueueProfilingSwitch(domainSwitch, constant.TaskDRank)
+}
+
+func (m *clusterdAdaptor) getAddr() (string, error) {
+	res := m.clusterInfo.Ip + ":" + m.clusterInfo.Port
+	if m.clusterInfo.Ip != "" && m.clusterInfo.Port != "" && net.ParseIP(m.clusterInfo.Ip) != nil {
+		return res, nil
+	}
+	addr, err := utils.GetClusterdAddr()
+	if err != nil {
+		return "", fmt.Errorf("get address from %v and env err: %v", m.clusterInfo, err)
+	}
+	return addr, nil
+}
+
+func convertProfilingMsg(profilingSwitchData *profiling.ProfilingSwitch) constant.ProfilingSwitch {
+	profilingSwitch := constant.ProfilingSwitch{
+		CommunicationOperator: profilingSwitchData.CommunicationOperator,
+		Step:                  profilingSwitchData.Step,
+		SaveCheckpoint:        profilingSwitchData.SaveCheckpoint,
+		FP:                    profilingSwitchData.FP,
+		DataLoader:            profilingSwitchData.DataLoader,
+	}
+	return profilingSwitch
+}
diff --git a/component/taskd/taskd/go/framework_backend/manager/service/adaptor/factory.go b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/factory.go
new file mode 100644
index 0000000000000000000000000000000000000000..babe4c66e7e916362a63ac0693991d29a59c2e96
--- /dev/null
+++ b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/factory.go
@@ -0,0 +1,61 @@
+/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+// Package adaptor provides cluster adaptors, e.g. clusterd, taskd
+package adaptor
+
+import (
+	"context"
+	"fmt"
+
+	"ascend-common/common-utils/hwlog"
+	"taskd/common/constant"
+	"taskd/framework_backend/manager/infrastructure"
+	"taskd/framework_backend/manager/infrastructure/storage"
+)
+
+// adaptorMap stores created adaptors by role; it is filled during manager init and read-only afterwards
+var adaptorMap map[string]Adaptor
+
+// InitAdaptor returns an adaptor according to the Role of ClusterInfo
+func InitAdaptor(
+	jobId string, svcCtx context.Context, queue *storage.MsgQueue, info constant.ClusterInfo) (Adaptor, error) {
+	if adaptorMap == nil {
+		adaptorMap = make(map[string]Adaptor)
+	}
+	if info.Role == constant.ClusterDRank {
+		clusterd := newClusterdAdaptor(jobId, svcCtx, queue, info)
+		adaptorMap[constant.ClusterDRank] = clusterd
+		return clusterd, nil
+	}
+	if info.Role == constant.TaskDRank {
+		taskd := NewTaskdAdaptor(jobId, svcCtx, queue, info)
+		adaptorMap[constant.TaskDRank] = taskd
+		return taskd, nil
+	}
+	return nil, fmt.Errorf("init adaptor for %v failed", info)
+}
+
+// DistributeMsg sends msg to the cluster adaptor matching receiver; it returns true if the msg is taken over
+func DistributeMsg(msg infrastructure.Msg, receiver string) bool {
+	if clusterAdaptor, found := adaptorMap[receiver]; found {
+		go func() {
+			if err := clusterAdaptor.SendMsg(msg.Body); err != nil {
+				hwlog.RunLog.Error(err)
+			}
+		}()
+		return true
+	}
+	return false
+}
diff --git a/component/taskd/taskd/go/framework_backend/manager/service/adaptor/taskd.go b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/taskd.go
new file mode 100644
index 0000000000000000000000000000000000000000..bff0fbefbd61843098a5813f2410f5416ab844ca
--- /dev/null
+++ b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/taskd.go
@@ -0,0 +1,60 @@
+/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+// Package adaptor provides cluster adaptors, e.g. clusterd, taskd
+package adaptor
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/google/uuid"
+
+	"taskd/common/constant"
+	"taskd/framework_backend/manager/infrastructure/storage"
+	"taskd/toolkit_backend/net/common"
+)
+
+// NewTaskdAdaptor returns a taskd adaptor instance
+func NewTaskdAdaptor(
+	jobId string, svcCtx context.Context, queue *storage.MsgQueue, info constant.ClusterInfo) *taskdAdaptor {
+	return &taskdAdaptor{
+		baseAdaptor: &baseAdaptor{
+			jobId:       jobId,
+			svcCtx:      svcCtx,
+			queue:       queue,
+			clusterInfo: info,
+		},
+	}
+}
+
+// Handle needs no extra connection for the manager itself, so it is a no-op
+func (t *taskdAdaptor) Handle() error {
+	return nil
+}
+
+// SendMsg enqueues msg into the manager message queue
+func (t *taskdAdaptor) SendMsg(msg storage.MsgBody) error {
+	newMsg := t.queue.NewMsg(uuid.New().String(), constant.DefaultBizType,
+		&common.Position{
+			Role:        common.MgrRole,
+			ServerRank:  "0",
+			ProcessRank: "-1",
+		}, msg)
+	err := t.queue.Enqueue(newMsg)
+	if err != nil {
+		return fmt.Errorf("send msg to taskd err: %v", err)
+	}
+	return nil
+}
diff --git a/component/taskd/taskd/go/framework_backend/manager/service/adaptor/type.go b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/type.go
new file mode 100644
index 0000000000000000000000000000000000000000..2f4ef2f80f4fc5c7009395fd48f23f17fcaf9785
--- /dev/null
+++ b/component/taskd/taskd/go/framework_backend/manager/service/adaptor/type.go
@@ -0,0 +1,46 @@
+/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+*/
+
+// Package adaptor provides cluster adaptors, e.g. clusterd, taskd
+package adaptor
+
+import (
+	"context"
+	"sync/atomic"
+
+	"taskd/common/constant"
+	"taskd/framework_backend/manager/infrastructure/storage"
+)
+
+// Adaptor is the interface of cluster adaptors
+type Adaptor interface {
+	Handle() error
+	SendMsg(msg storage.MsgBody) error
+}
+
+type baseAdaptor struct {
+	jobId       string
+	svcCtx      context.Context
+	queue       *storage.MsgQueue
+	clusterInfo constant.ClusterInfo
+}
+
+type clusterdAdaptor struct {
+	*baseAdaptor
+	profilingFromClusterD atomic.Bool
+}
+
+type taskdAdaptor struct {
+	*baseAdaptor
+}
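
Illustration (not part of the patch): a minimal, hypothetical sketch of the JSON that the reworked InitTaskdManager in backend_api.go now unmarshals into constant.Config. Field names follow the json tags added in common/constant/type.go; the address values are placeholders, and the role strings must match the literal values of constant.ClusterDRank and constant.TaskDRank, which this diff does not show. The local types mirror constant.Config for a self-contained example only.

// Sketch under the assumptions stated above; the "role" value "clusterd" is a guess.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	raw := `{
		"job_id": "job-0",
		"node_nums": 1,
		"proc_per_node": 8,
		"plugin_dir": "/opt/taskd/plugins",
		"cluster_infos": [
			{"ip": "127.0.0.1", "port": "8899", "name": "clusterd", "role": "clusterd"}
		]
	}`
	// Local mirrors of constant.ClusterInfo / constant.Config for illustration.
	type ClusterInfo struct {
		Ip   string `json:"ip"`
		Port string `json:"port"`
		Name string `json:"name"`
		Role string `json:"role"`
	}
	type Config struct {
		JobId        string        `json:"job_id"`
		NodeNums     int           `json:"node_nums"`
		ProcPerNode  int           `json:"proc_per_node"`
		PluginDir    string        `json:"plugin_dir"`
		ClusterInfos []ClusterInfo `json:"cluster_infos"`
	}
	var cfg Config
	// InitTaskdManager returns C.int(1) when this unmarshal fails.
	if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
		fmt.Println("unmarshal err:", err)
		return
	}
	// With this config, clusterHandle would create a clusterd adaptor from the
	// supplied entry and fall back to a default taskd adaptor, since no taskd
	// role is listed.
	fmt.Printf("parsed config: %+v\n", cfg)
}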