diff --git a/component/taskd/taskd/go/common/constant/const.go b/component/taskd/taskd/go/common/constant/const.go index 799a81ffe96d6e0a3402a4b5f3609abb5c5aada4..936e753fc29ccc4e9362db32527a14562d0e0a80 100644 --- a/component/taskd/taskd/go/common/constant/const.go +++ b/component/taskd/taskd/go/common/constant/const.go @@ -125,6 +125,7 @@ const ( FaultRankCode = 202 ExitAgentCode = 203 SwitchNicCode = 204 + StartAgentWorkerCode = 205 ProfilingAllCloseCmdCode = 700 ProfilingDefaultDomainOnCode = 710 ProfilingCommDomainOnCode = 701 @@ -164,9 +165,10 @@ const ( // All cluster info type must be defined here const ( - ClusterRole = "Cluster" - ClusterDRank = "ClusterD" - TaskDRank = "TaskD" + ClusterRole = "Cluster" + ClusterDRank = "ClusterD" + TaskDRank = "TaskD" + ControllerRole = "Controller" ) // All cluster command must be defined here @@ -299,3 +301,27 @@ const ( // LocalProxyEnableOn local proxy enable value LocalProxyEnableOn = "on" ) + +const ( + // JobReschedulingPluginName name of job rescheduling plugin + JobReschedulingPluginName = "JobReschedulingPlugin" + // JobReschedulingStreamName name of job rescheduling stream + JobReschedulingStreamName = "JobReschedulingStream" + // SingalKillMaster singal kill master + SingalKillMaster = "killMaster" + // RestartController restart controller + RestartController = "restart_controller" + // DestroyController destroy controller + DestroyController = "destroy_controller" + // Actions actions + Actions = "actions" + // RecoverStreamName name of recover stream + RecoverStreamName = "RecoverStream" +) + +const ( + // PodReschedulingPluginName name of pod rescheduling plugin + PodReschedulingPluginName = "PodReschedulingPlugin" + // PodReschedulingStreamName name of pod rescheduling stream + PodReschedulingStreamName = "PodReschedulingStream" +) diff --git a/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go b/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go index ace8e2d14ccf18e6db8385009e6f7290b08c2269..dda5c1e2f0561b89baad9c28ef90dc99ee4ee385 100644 --- a/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go +++ b/component/taskd/taskd/go/framework_backend/manager/application/msghandler.go @@ -19,6 +19,7 @@ import ( "context" "errors" "os" + "strconv" "sync" "github.com/google/uuid" @@ -73,6 +74,10 @@ func NewMsgHandler() *MsgHandler { AllStatus: map[string]string{}, RWMutex: sync.RWMutex{}, }, + MgrInfos: &storage.MgrInfo{ + Status: map[string]string{}, + RWMutex: sync.RWMutex{}, + }, }, RWMutex: sync.RWMutex{}, }, @@ -157,10 +162,55 @@ func (mhd *MsgHandler) processOne(ctx context.Context) { if err != nil { hwlog.RunLog.Error(err) } + if msg.Header.Src.Role == common.AgentRole && msg.Body.Code == constant.RestartTimeCode { + mhd.responseAgentRestartTimes(msg) + } } } } +func (mhd *MsgHandler) responseAgentRestartTimes(msg storage.BaseMessage) { + mgrInfo, err := mhd.DataPool.GetMgr() + if mgrInfo == nil { + hwlog.RunLog.Errorf("responseAgentRestartTimes: failed to get manager info, mgrInfo is nil") + return + } + if err != nil { + hwlog.RunLog.Errorf("responseAgentRestartTimes: failed to get manager info, err: %v", err) + return + } + mgrRestartTimes := 0 + if restartTimeStr, exists := mgrInfo.Status[constant.ReportRestartTime]; exists && restartTimeStr != "" { + var parseErr error + mgrRestartTimes, parseErr = strconv.Atoi(restartTimeStr) + if parseErr != nil { + hwlog.RunLog.Errorf("responseAgentRestartTimes: failed to parse manager restart times '%s', err: 
%v", restartTimeStr, parseErr) + } + } else { + hwlog.RunLog.Infof("responseAgentRestartTimes: manager restart time not found or empty, using default value 0") + } + agentRestartTimes := 0 + if msg.Body.Message != "" { + var parseErr error + agentRestartTimes, parseErr = strconv.Atoi(msg.Body.Message) + if parseErr != nil { + hwlog.RunLog.Errorf("responseAgentRestartTimes: failed to parse agent restart times '%s', err: %v", msg.Body.Message, parseErr) + } + } + restartTimes := mgrRestartTimes + if mgrRestartTimes == 0 { + restartTimes = agentRestartTimes + hwlog.RunLog.Debugf("responseAgentRestartTimes: using agent restart times %d as manager restart times is 0", agentRestartTimes) + } + msgBody := storage.MsgBody{ + MsgType: constant.Action, + Code: constant.StartAgentWorkerCode, + Message: strconv.Itoa(restartTimes), + } + hwlog.RunLog.Infof("responseAgentRestartTimes: sending response with restart times %d", restartTimes) + mhd.SendMsgUseGrpc(msg.Header.BizType, utils.ObjToString(msgBody), msg.Header.Src) +} + func (mhd *MsgHandler) receiver(tool *net.NetInstance, ctx context.Context) { for i := 0; i < constant.RequestChanNum; i++ { go mhd.receiveGoroutine(tool, ctx) diff --git a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/datapool.go b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/datapool.go index 549387243d041633ad787e8873090979abd76191..1969fb598c80fc3031766f552abffda81ce4b3e6 100644 --- a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/datapool.go +++ b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/datapool.go @@ -36,6 +36,7 @@ type SnapShot struct { AgentInfos *AgentInfos WorkerInfos *WorkerInfos ClusterInfos *ClusterInfos + MgrInfos *MgrInfo } // MsgQueue the queue store message @@ -125,6 +126,15 @@ func (d *DataPool) UpdateCluster(clusterName string, clusterInfo *ClusterInfo) e return err } +// UpdateMgr update mgr info in the data pool +func (d *DataPool) UpdateMgr(mgrInfo *MgrInfo) error { + if d == nil || d.Snapshot == nil || d.Snapshot.MgrInfos == nil || d.Snapshot.MgrInfos.Status == nil { + return fmt.Errorf("mgr is not initialized") + } + err := d.Snapshot.MgrInfos.updateMgr(mgrInfo) + return err +} + // GetAgent return agent info about agent name func (d *DataPool) GetAgent(agentName string) (*AgentInfo, error) { return d.Snapshot.AgentInfos.getAgent(agentName) @@ -140,6 +150,11 @@ func (d *DataPool) GetCluster(clusterName string) (*ClusterInfo, error) { return d.Snapshot.ClusterInfos.getCluster(clusterName) } +// GetMgr return mgr info about mgr name +func (d *DataPool) GetMgr() (*MgrInfo, error) { + return d.Snapshot.MgrInfos.getMgrInfo() +} + // GetPos return worker or agent position func (d *DataPool) GetPos(infoType, name string) (*common.Position, error) { switch infoType { diff --git a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/mgr_infos.go b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/mgr_infos.go new file mode 100644 index 0000000000000000000000000000000000000000..9ea444d98452b2ab2f829e07a65f12f5d6aaf736 --- /dev/null +++ b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/mgr_infos.go @@ -0,0 +1,41 @@ +/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package storage for taskd manager backend data type +package storage + +import ( + "sync" +) + +type MgrInfo struct { + Status map[string]string + RWMutex sync.RWMutex +} + +func (m *MgrInfo) getMgrInfo() (*MgrInfo, error) { + m.RWMutex.RLock() + defer m.RWMutex.RUnlock() + return &MgrInfo{ + Status: m.Status, + RWMutex: sync.RWMutex{}, + }, nil +} + +func (m *MgrInfo) updateMgr(newMgr *MgrInfo) error { + m.RWMutex.Lock() + defer m.RWMutex.Unlock() + m.Status = newMgr.Status + return nil +} diff --git a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go index 4cb936b17f81f072d48fd21a2a214352ad95aa44..0aa57bc0fabdfd71fc12174233052a68277fc062 100644 --- a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go +++ b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go @@ -42,3 +42,9 @@ type MsgBody struct { Message string `json:"message"` Extension map[string]string `json:"extension"` } + +// AgentReportInfo agent report info +type AgentReportInfo struct { + FaultRanks []int `json:"fault_ranks"` + RestartTime int `json:"restart_time"` +} diff --git a/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go b/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go new file mode 100644 index 0000000000000000000000000000000000000000..9847e4bd95018537e6aeba586ad289b240ad434d --- /dev/null +++ b/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go @@ -0,0 +1,182 @@ +/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +// Package jobrescheduling for taskd manager plugin +package jobrescheduling + +import ( + "ascend-common/common-utils/hwlog" + "taskd/common/constant" + "taskd/common/utils" + "taskd/framework_backend/manager/infrastructure" + "taskd/framework_backend/manager/infrastructure/storage" + "taskd/toolkit_backend/net/common" +) + +// JobReschedulingPlugin job rescheduling plugin +type JobReschedulingPlugin struct { + pullMsgs []infrastructure.Msg + faultOccur bool + processStatus string + agentStatus map[int]bool + killMaster bool +} + +var ( + agent0ExitMsg = infrastructure.Msg{ + Receiver: []string{common.AgentRole + "0"}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: constant.ExitAgentCode, + }, + } +) + +// NewJobReschedulingPlugin new job rescheduling plugin +func NewJobReschedulingPlugin() infrastructure.ManagerPlugin { + return &JobReschedulingPlugin{ + pullMsgs: make([]infrastructure.Msg, 0), + faultOccur: false, + processStatus: "", + agentStatus: make(map[int]bool), + killMaster: false, + } +} + +// Name name of job rescheduling plugin +func (job *JobReschedulingPlugin) Name() string { + return constant.JobReschedulingPluginName +} + +// Handle handle job rescheduling plugin +func (job *JobReschedulingPlugin) Handle() (infrastructure.HandleResult, error) { + job.processStatus = constant.HandleStageProcess + if job.killMaster { + job.pullMsgs = append(job.pullMsgs, infrastructure.Msg{ + Receiver: []string{constant.ControllerRole}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: 0, + Extension: map[string]string{ + constant.Actions: utils.ObjToString([]string{constant.DestroyController}), + }, + }, + }) + job.pullMsgs = append(job.pullMsgs, agent0ExitMsg) + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil + } + if !job.faultOccur { + hwlog.RunLog.Info("JobReschedulingPlugin not fault occur") + job.resetPluginInfo() + return infrastructure.HandleResult{ + Stage: constant.HandleStageFinal, + }, nil + } + + if value, ok := job.agentStatus[0]; ok && value == true { + hwlog.RunLog.Info("JobReschedulingPlugin agent 0 exit") + job.resetPluginInfo() + return infrastructure.HandleResult{ + Stage: constant.HandleStageFinal, + }, nil + } + hwlog.RunLog.Info("JobReschedulingPlugin handle fault") + job.pullMsgs = append(job.pullMsgs, agent0ExitMsg) + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil +} + +// Handle handle job rescheduling plugin +func (job *JobReschedulingPlugin) PullMsg() ([]infrastructure.Msg, error) { + msgs := job.pullMsgs + job.pullMsgs = make([]infrastructure.Msg, 0) + return msgs, nil +} + +// Predicate predicate job rescheduling plugin +func (job *JobReschedulingPlugin) Predicate(shot storage.SnapShot) (infrastructure.PredicateResult, error) { + if job.processStatus != "" { + hwlog.RunLog.Infof("JobReschedulingPlugin Predicate processStatus:%v", job.processStatus) + job.updatePluginInfo(shot) + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.RecoverStreamName: "", + }}, nil + } + job.resetPluginInfo() + clusterInfo, ok := shot.ClusterInfos.Clusters[constant.ClusterRole] + if ok { + if clusterInfo.Command[constant.SingalKillMaster] != "" { + job.killMaster = true + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.RecoverStreamName: "", + }}, nil + } + } + + for _, agent := range 
shot.AgentInfos.Agents { + if agent.Status[constant.ReportFaultRank] != "" { + job.faultOccur = true + hwlog.RunLog.Infof("JobReschedulingPlugin candidate token, info:%v", agent.Status[constant.ReportFaultRank]) + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.RecoverStreamName: "", + }}, nil + } + } + hwlog.RunLog.Info("JobReschedulingPlugin not fault occur") + return infrastructure.PredicateResult{ + PluginName: job.Name(), CandidateStatus: constant.UnselectStatus, PredicateStream: nil}, nil +} + +// Release release job rescheduling plugin +func (job *JobReschedulingPlugin) Release() error { + return nil +} + +func (job *JobReschedulingPlugin) resetPluginInfo() { + job.pullMsgs = make([]infrastructure.Msg, 0) + job.faultOccur = false + job.processStatus = "" + job.agentStatus = make(map[int]bool) + job.killMaster = false +} + +func (job *JobReschedulingPlugin) updatePluginInfo(shot storage.SnapShot) { + agenInfo, ok := shot.AgentInfos.Agents[common.AgentRole+"0"] + if !ok { + hwlog.RunLog.Errorf("JobReschedulingPlugin updatePluginInfo agent 0 not exist") + job.resetPluginInfo() + return + } + if agenInfo.Status[constant.Exit] != "" { + job.agentStatus[0] = true + } + clusterInfo, ok := shot.ClusterInfos.Clusters[constant.ClusterRole] + if ok { + if clusterInfo.Command[constant.SingalKillMaster] != "" { + job.killMaster = true + } + } + + for _, agent := range shot.AgentInfos.Agents { + if agent.Status[constant.ReportFaultRank] != "" { + job.faultOccur = true + break + } + } +} diff --git a/component/taskd/taskd/go/framework_backend/manager/plugins/podrescheduling/pod_rescheduling.go b/component/taskd/taskd/go/framework_backend/manager/plugins/podrescheduling/pod_rescheduling.go new file mode 100644 index 0000000000000000000000000000000000000000..29231c014600218206b95dfe43829b38ea38cb63 --- /dev/null +++ b/component/taskd/taskd/go/framework_backend/manager/plugins/podrescheduling/pod_rescheduling.go @@ -0,0 +1,191 @@ +/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +// Package jobrescheduling for taskd manager plugin +package podrescheduling + +import ( + "ascend-common/common-utils/hwlog" + "strconv" + "taskd/common/constant" + "taskd/common/utils" + "taskd/framework_backend/manager/infrastructure" + "taskd/framework_backend/manager/infrastructure/storage" + "taskd/toolkit_backend/net/common" +) + +// PodReschedulingPlugin pod rescheduling plugin +type PodReschedulingPlugin struct { + pullMsgs []infrastructure.Msg + processStatus string + faultAgentStatus map[string]bool + faultOccur bool + restartTimes int + exitNum int +} + +// NewPodReschedulingPlugin new pod rescheduling plugin +func NewPodReschedulingPlugin() infrastructure.ManagerPlugin { + return &PodReschedulingPlugin{ + pullMsgs: make([]infrastructure.Msg, 0), + processStatus: "", + faultAgentStatus: make(map[string]bool), + restartTimes: -1, + exitNum: 0, + faultOccur: false, + } +} + +// Name name of pod rescheduling plugin +func (pod *PodReschedulingPlugin) Name() string { + return constant.PodReschedulingPluginName +} + +// Handle handle pod rescheduling plugin +func (pod *PodReschedulingPlugin) Handle() (infrastructure.HandleResult, error) { + pod.processStatus = constant.HandleStageProcess + exitReciver := []string{} + restartReceiver := []string{} + for agentName, faultStatus := range pod.faultAgentStatus { + if faultStatus { + exitReciver = append(exitReciver, agentName) + } else { + restartReceiver = append(restartReceiver, agentName) + } + } + if len(exitReciver) == 0 { + pod.resetPluginInfo() + return infrastructure.HandleResult{Stage: constant.HandleStageFinal}, nil + } + if pod.exitNum != 0 { + hwlog.RunLog.Infof("pod rescheduling plugin handle, exit num: %d", pod.exitNum) + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil + } + pod.restartTimes -= 1 + hwlog.RunLog.Infof("pod rescheduling plugin handle, restart times: %d", pod.restartTimes) + hwlog.RunLog.Infof("pod rescheduling plugin handle, exit receiver: %v", exitReciver) + hwlog.RunLog.Infof("pod rescheduling plugin handle, restart receiver: %v", restartReceiver) + pod.addHandleMsgs(exitReciver, restartReceiver) + pod.exitNum = len(exitReciver) + + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil +} + +func (pod *PodReschedulingPlugin) addHandleMsgs(exitReciver []string, restartReceiver []string) { + pod.pullMsgs = append(pod.pullMsgs, infrastructure.Msg{ + Receiver: []string{common.MgrRole}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: constant.RestartTimeCode, + Message: strconv.Itoa(pod.restartTimes), + }, + }) + pod.pullMsgs = append(pod.pullMsgs, infrastructure.Msg{ + Receiver: exitReciver, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: constant.ExitAgentCode, + Extension: map[string]string{}, + }, + }) + pod.pullMsgs = append(pod.pullMsgs, infrastructure.Msg{ + Receiver: []string{constant.ControllerRole}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: 0, + Extension: map[string]string{ + constant.Actions: utils.ObjToString([]string{constant.RestartController}), + }, + }, + }) + + pod.pullMsgs = append(pod.pullMsgs, infrastructure.Msg{ + Receiver: restartReceiver, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: 206, // restart + Message: strconv.Itoa(pod.restartTimes), + }, + }) +} + +// Handle handle pod rescheduling plugin +func (pod *PodReschedulingPlugin) PullMsg() ([]infrastructure.Msg, error) { + msgs := pod.pullMsgs + pod.pullMsgs = make([]infrastructure.Msg, 0) + return msgs, nil +} + +// 
Predicate predicate job rescheduling plugin +func (pod *PodReschedulingPlugin) Predicate(shot storage.SnapShot) (infrastructure.PredicateResult, error) { + if pod.processStatus != "" { + pod.updatePluginInfo(shot) + return infrastructure.PredicateResult{PluginName: pod.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.RecoverStreamName: "", + }}, nil + } + pod.resetPluginInfo() + for agentName, agentInfo := range shot.AgentInfos.Agents { + pod.faultAgentStatus[agentName] = false + if agentName == common.AgentRole+"0" && agentInfo.Status[constant.ReportFaultRank] != "" { + hwlog.RunLog.Info("agent 0 fault, pod rescheduling plugin unselect") + return infrastructure.PredicateResult{ + PluginName: pod.Name(), CandidateStatus: constant.UnselectStatus, PredicateStream: nil}, nil + } + if agentInfo.Status[constant.ReportRestartTime] != "" && pod.restartTimes == -1 { + hwlog.RunLog.Infof("pod rescheduling first set plugin restart times: %v", agentInfo.Status[constant.ReportRestartTime]) + pod.restartTimes, _ = strconv.Atoi(agentInfo.Status[constant.ReportRestartTime]) + } + if agentInfo.Status[constant.ReportFaultRank] != "" { + pod.faultAgentStatus[agentName] = true + pod.faultOccur = true + } + } + hwlog.RunLog.Debugf("pod rescheduling plugin predicate, fault agent status: %v", pod.faultAgentStatus) + if pod.faultOccur { + return infrastructure.PredicateResult{PluginName: pod.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.RecoverStreamName: "", + }}, nil + } + + return infrastructure.PredicateResult{ + PluginName: pod.Name(), CandidateStatus: constant.UnselectStatus, PredicateStream: nil}, nil +} + +// Release release pod rescheduling plugin +func (pod *PodReschedulingPlugin) Release() error { + return nil +} + +func (pod *PodReschedulingPlugin) resetPluginInfo() { + pod.pullMsgs = make([]infrastructure.Msg, 0) + pod.processStatus = "" + pod.faultAgentStatus = make(map[string]bool) + pod.exitNum = 0 + pod.faultOccur = false +} + +func (pod *PodReschedulingPlugin) updatePluginInfo(shot storage.SnapShot) { + for agentName, agentInfo := range shot.AgentInfos.Agents { + if agentInfo.Status[constant.ReportFaultRank] != "" { + pod.faultAgentStatus[agentName] = true + } else { + pod.faultAgentStatus[agentName] = false + } + } +} diff --git a/component/taskd/taskd/go/framework_backend/manager/service/agent.go b/component/taskd/taskd/go/framework_backend/manager/service/agent.go index a1f22f443109fe5e3c85aa2be2d495c79ac72b0b..997a5f007248c6aa60d3f1911963fae9626fd7e4 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/agent.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/agent.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "ascend-common/common-utils/hwlog" "taskd/common/constant" "taskd/framework_backend/manager/infrastructure/storage" ) @@ -35,6 +36,7 @@ func (mpc *MsgProcessor) agentHandler(dataPool *storage.DataPool, data storage.B } switch data.Body.MsgType { case constant.STATUS: + hwlog.RunLog.Infof("agent status message, data: %v", data) err = mpc.agentStatus(dataPool, data, agentName, agentInfo) case constant.KeepAlive: agentInfo.HeartBeat = time.Now() @@ -62,6 +64,7 @@ func (mpc *MsgProcessor) agentStatus(dataPool *storage.DataPool, data storage.Ba agentInfo *storage.AgentInfo) error { switch data.Body.Code { case constant.RestartTimeCode: + hwlog.RunLog.Infof("agent restart time message, data: %v", data) agentInfo.Status[constant.ReportRestartTime] = 
data.Body.Message case constant.FaultRankCode: agentInfo.Status[constant.ReportFaultRank] = data.Body.Message diff --git a/component/taskd/taskd/go/framework_backend/manager/service/manager_handler.go b/component/taskd/taskd/go/framework_backend/manager/service/manager_handler.go index 1412125173c9b22eec4c508493998c72fcf776eb..4c36e68e97838bae5d6609211d15db0d11566086 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/manager_handler.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/manager_handler.go @@ -16,9 +16,25 @@ package service import ( + "fmt" + "taskd/common/constant" "taskd/framework_backend/manager/infrastructure/storage" ) func (mpc *MsgProcessor) managerHandler(dataPool *storage.DataPool, msg storage.BaseMessage) error { - return nil + mgrInfo, err := dataPool.GetMgr() + if err != nil { + return err + } + switch msg.Body.MsgType { + case constant.Action: + if msg.Body.Code == constant.RestartTimeCode { + mgrInfo.Status[constant.ReportRestartTime] = msg.Body.Message + return nil + } + default: + return fmt.Errorf("unknown message type: %v", msg.Body.MsgType) + } + err = dataPool.UpdateMgr(mgrInfo) + return err } diff --git a/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go b/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go index 59f48f146e14a0dd0475cfbe1289cf7583577a8c..b2df98070d25cba7eb36822117708f656fff2c22 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go @@ -22,7 +22,9 @@ import ( "taskd/framework_backend/manager/infrastructure" "taskd/framework_backend/manager/infrastructure/storage" "taskd/framework_backend/manager/plugins/faultdig" + jobrescheduling "taskd/framework_backend/manager/plugins/job_rescheduling" "taskd/framework_backend/manager/plugins/om" + "taskd/framework_backend/manager/plugins/podrescheduling" ) // PluginHandlerInterface define the interface of plugin handler @@ -52,6 +54,17 @@ func (p *PluginHandler) Init() error { hwlog.RunLog.Errorf("register plugin %s failed!", omPlugin.Name()) return fmt.Errorf("register plugin %s failed", omPlugin.Name()) } + jobReschedulingPlugin := jobrescheduling.NewJobReschedulingPlugin() + if err := p.Register(jobReschedulingPlugin.Name(), jobReschedulingPlugin); err != nil { + hwlog.RunLog.Errorf("register plugin %s failed!", jobReschedulingPlugin.Name()) + return fmt.Errorf("register plugin %s failed", jobReschedulingPlugin.Name()) + } + podReschedulingPlugin := podrescheduling.NewPodReschedulingPlugin() + if err := p.Register(podReschedulingPlugin.Name(), podReschedulingPlugin); err != nil { + hwlog.RunLog.Errorf("register plugin %s failed!", podReschedulingPlugin.Name()) + return fmt.Errorf("register plugin %s failed", podReschedulingPlugin.Name()) + } + return nil } diff --git a/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go b/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go index 51d7c4725171719a35bf8a87d61809b96e711a57..4981d7b32033a8781c730ddad181fe78640be476 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go @@ -66,6 +66,14 @@ func (s *StreamHandler) Init() error { OmStream.GetName()) return err } + RecoverStream := infrastructure.NewStream(constant.RecoverStreamName, + map[string]int{constant.JobReschedulingPluginName: 4, 
constant.PodReschedulingPluginName: 3}) + if err := s.SetStream(RecoverStream); err != nil { + hwlog.RunLog.Errorf("init stream handler failed: set stream %s failed", + RecoverStream.GetName()) + return err + } + return nil } diff --git a/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py b/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py index f4b9571aadcfe64dd8960c41a5e675e9182d0a6c..8273742e41b2416ce346d405927465b80e383a6d 100644 --- a/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py +++ b/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py @@ -27,6 +27,7 @@ from taskd.python.utils.log import run_log from taskd.python.cython_api import cython_api from taskd.python.framework.common.type import MsgBody, MessageInfo, Position, DEFAULT_BIZTYPE from taskd.python.toolkit.constants.constants import SEND_RETRY_TIMES +from taskd.python.toolkit.constants import constants class AgentMessageManager(): @@ -78,12 +79,13 @@ class AgentMessageManager(): run_log.info(f"agent register: {msg}") self.send_message(msg) - def send_message(self, message: MessageInfo): + def send_message(self, message: MessageInfo, code: int = 0): """ Send message to taskd manager. """ run_log.debug(f"agent send message: {message}") msg_json = json.dumps(asdict(message)).encode('utf-8') + self.lib.DestroyNetTool.argtypes = [ctypes.c_void_p] send_times = 0 self.lib.SyncSendMessage.argtypes = [ctypes.c_void_p, ctypes.c_char_p] self.lib.SyncSendMessage.restype = ctypes.c_int @@ -94,6 +96,9 @@ class AgentMessageManager(): result = self.lib.SyncSendMessage(self._network_instance, msg_json) if result == 0: run_log.info(f"agent send message success, msg: {message.uuid}") + if code == constants.EXITAGENTCODE: + run_log.info(f"agent send exit message, msg: {message.uuid}") + self.lib.DestroyNetTool(self._network_instance) break run_log.warning(f"agent send message failed, result: {result}") send_times += 1 @@ -106,6 +111,8 @@ class AgentMessageManager(): while True: self.lib.ReceiveMessageC.argtypes = [ctypes.c_void_p] self.lib.ReceiveMessageC.restype = ctypes.c_void_p + self.lib.FreeCMemory.argtypes = [ctypes.c_void_p] + self.lib.FreeCMemory.restype = None msg_ptr = self.lib.ReceiveMessageC(self._network_instance) if msg_ptr is None: continue @@ -117,9 +124,9 @@ class AgentMessageManager(): continue self.msg_queue.put(msg) self.lib.FreeCMemory(msg_ptr) - if msg.msg_type == "exit": - self.lib.DestroyNetwork(self._network_instance) - return + #if msg.msg_type == "exit": + # self.lib.DestroyNetwork(self._network_instance) + # return def get_network_instance(self): """ @@ -205,7 +212,7 @@ def get_message_manager() -> AgentMessageManager: return AgentMessageManager.instance -def network_send_message(msg: MessageInfo): +def network_send_message(msg: MessageInfo, code: int = 0): """ Send message to taskd manager. 
""" @@ -213,7 +220,7 @@ def network_send_message(msg: MessageInfo): if msg_manager.get_network_instance() is None: run_log.warning("network instance is None") return - msg_manager.send_message(msg) + msg_manager.send_message(msg, code) def get_msg_network_instance(): diff --git a/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py b/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py index dc3c20cce54e7297bba4e50d53a56c177b777d0e..345182aa04d06a9f350235bc80d96a3afc87edf0 100644 --- a/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py +++ b/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py @@ -34,7 +34,7 @@ DEFAULT_DST = { } -REPORT_CODE = 601 +REPORT_CODE = 202 DEFAULT_MSG_TYPE = "DEFAULT" STATUS_MSG_TYPE = "STATUS" @@ -76,15 +76,16 @@ class BaseAgent: raise NotImplementedError def handle_message(self): - try: - item = self.msg_queue.get_nowait() - except queue.Empty: - run_log.debug('msg_queue is empty') - return - self.command_map.get(item.MsgType)(item) + while not self.msg_queue.empty(): + try: + item = self.msg_queue.get_nowait() + except queue.Empty: + run_log.debug('msg_queue is empty') + return + self.command_map.get(item.code)(item) def grace_exit(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to grace exit workers') + run_log.info(f'receive {msg.code} command, start to grace exit workers') try: grace_exit_pids(self.pids) except Exception as e: @@ -94,7 +95,10 @@ class BaseAgent: def send_message_to_manager(self, command, code, report_info): - report_json = json.dumps(asdict(report_info)) + if isinstance(report_info, str): + report_json = report_info + else: + report_json = json.dumps(asdict(report_info)) msg_body = MsgBody( msg_type=command, code=code, @@ -108,7 +112,7 @@ class BaseAgent: dst=DEFAULT_DST, body=body_json ) - network_send_message(msg_info) + network_send_message(msg_info, code) def check_network(self): time_cost = 0 diff --git a/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py b/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py index 23b94fc8fda57ec2e702ca99d1253a18abbbe6e5..7ed398985b2dd61a4bb5e3f1d2db8dd575ec2f74 100644 --- a/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py +++ b/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py @@ -50,7 +50,7 @@ class MsAgent(BaseAgent): self.command_map = { 'START': self.initialize_workers, 'STOP': self.stop_workers, - 'EXIT': self.exit_agent, + constants.EXITAGENTCODE: self.exit_agent, 'RESTART': self.restart_workers, 'GRACE_EXIT': self.grace_exit, } @@ -123,20 +123,20 @@ class MsAgent(BaseAgent): def initialize_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to initialize workers') + run_log.info(f'receive {msg.code} command, start to initialize workers') self._func_map.get('START_ALL_WORKER')() def stop_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to stop workers') + run_log.info(f'receive {msg.code} command, start to stop workers') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) def exit_agent(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to exit agent') + run_log.info(f'receive {msg.code} command, start to exit agent') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) self.send_message_to_manager('STATUS', REPORT_CODE, AgentReportInfo()) exit(1) def restart_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to restart workers') + 
run_log.info(f'receive {msg.code} command, start to restart workers') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) self._func_map.get('START_ALL_WORKER')() diff --git a/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py b/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py index 9e3b6ea2758457cf3a851c945be9bc76dc71ed93..18772707c91e8f8a97133ced510a7e773212089e 100644 --- a/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py +++ b/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py @@ -21,6 +21,7 @@ from taskd.python.utils.log import run_log from taskd.python.framework.agent.base_agent.agent_network import init_network_client from taskd.python.framework.agent.base_agent.base_agent import BaseAgent, REPORT_CODE from taskd.python.framework.common.type import AgentReportInfo +from taskd.python.toolkit.constants import constants try: from torch.distributed.elastic.agent.server.api import WorkerState, RunResult except ImportError: @@ -41,10 +42,10 @@ class PtAgent(BaseAgent): self.local_world_size = cls._worker_group.spec.local_world_size self.network_config = network_config self.command_map = { - 'START': self.initialize_workers, + 205: self.initialize_workers, 'STOP': self.stop_workers, - 'EXIT': self.exit_agent, - 'RESTART': self.restart_workers, + constants.EXITAGENTCODE: self.exit_agent, + 206: self.restart_workers, 'GRACE_EXIT': self.grace_exit, } self.logger = logger @@ -55,7 +56,7 @@ class PtAgent(BaseAgent): spec = self.worker_group.spec role = spec.role run_log.info("[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()) - self._func_map.get('START_ALL_WORKER')(self.worker_group) + self.start_worker() self.update_agent_info() monitor_interval = spec.monitor_interval @@ -97,25 +98,42 @@ class PtAgent(BaseAgent): return def initialize_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, restart time is {msg.extension},' + run_log.info(f'receive {msg.code} command, restart time is {msg.extension},' f' start to initialize workers') - self.pt_instance._remaining_restarts = int(msg.extension) + self.pt_instance._remaining_restarts = int(msg.message) self._func_map.get('START_ALL_WORKER')(self.worker_group) def stop_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to stop workers') + run_log.info(f'receive {msg.code} command, start to stop workers') self._func_map.get('KILL_WORKER')(self.worker_group) self.worker_group.state = WorkerState.STOPPED def exit_agent(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to exit agent') + run_log.info(f'receive {msg.code} command, start to exit agent') self._func_map.get('KILL_WORKER')(self.worker_group) - self.send_message_to_manager('STATUS', REPORT_CODE, AgentReportInfo()) + self.send_message_to_manager('STATUS', constants.EXITAGENTCODE, AgentReportInfo()) exit(1) def restart_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to restart workers, restart time is {msg.extension}') - self.pt_instance._remaining_restarts = int(msg.extension) + run_log.info(f'receive {msg.code} command, start to restart workers, restart time is {msg.message}') + self.pt_instance._remaining_restarts = int(msg.message) self._func_map.get('KILL_WORKER')(self.worker_group) self.worker_group.state = WorkerState.STOPPED self._func_map.get('START_ALL_WORKER')(self.worker_group) + + def start_worker(self): + time_use = 0 + self.send_message_to_manager('STATUS', 201, 
str(self.pt_instance._remaining_restarts))
+        run_log.info(f"agent {self.node_rank} start worker, restart times is {self.pt_instance._remaining_restarts}")
+        while True:
+            try:
+                item = self.msg_queue.get_nowait()
+                if item.code == 205:
+                    self.command_map.get(item.code)(item)
+                    break
+            except queue.Empty:
+                run_log.debug('msg_queue is empty')
+            time.sleep(1)
+            time_use += 1
+            if time_use > 15:
+                raise RuntimeError("start_worker timeout")
+
diff --git a/component/taskd/taskd/python/toolkit/constants/constants.py b/component/taskd/taskd/python/toolkit/constants/constants.py
index 17440fdc919c0df63d7121f1906f8aa7af0712d2..bef5e2d5eae50f3aed79dd47b0bae022e56f8fb8 100644
--- a/component/taskd/taskd/python/toolkit/constants/constants.py
+++ b/component/taskd/taskd/python/toolkit/constants/constants.py
@@ -107,4 +107,6 @@ STOP_TRAIN_PAUSE = "pause"
 SWITCH_NIC_DEFAULT_TIMEOUT = 600
 SWITCH_NIC_MAX_TIMEOUT = 120 * 60
 
-HCCL_CONNECT_TIMEOUT = "HCCL_CONNECT_TIMEOUT"
\ No newline at end of file
+HCCL_CONNECT_TIMEOUT = "HCCL_CONNECT_TIMEOUT"
+
+EXITAGENTCODE = 203
\ No newline at end of file
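
Note on the restart-times handshake introduced above: PtAgent.start_worker reports the agent's remaining restart count as a STATUS message, MsgProcessor.agentStatus stores it under ReportRestartTime, and MsgHandler.responseAgentRestartTimes replies with StartAgentWorkerCode carrying the count the agent should use. The snippet below is a minimal, self-contained Go sketch of only the manager's fallback rule for picking that count; the helper name chooseRestartTimes and the literal sample values are illustrative assumptions, not part of this patch.

package main

import (
	"fmt"
	"strconv"
)

// chooseRestartTimes mirrors the fallback in responseAgentRestartTimes:
// prefer the restart count the manager has stored; when the manager holds
// nothing usable (empty, unparsable, or zero), fall back to the count the
// agent just reported. An unparsable agent report degrades to zero as well.
func chooseRestartTimes(mgrStored, agentReported string) int {
	mgrTimes, err := strconv.Atoi(mgrStored)
	if err != nil {
		mgrTimes = 0
	}
	if mgrTimes != 0 {
		return mgrTimes
	}
	agentTimes, err := strconv.Atoi(agentReported)
	if err != nil {
		return 0
	}
	return agentTimes
}

func main() {
	// First report: the manager has stored nothing yet, so the agent's own
	// count is echoed back in the StartAgentWorkerCode reply.
	fmt.Println(chooseRestartTimes("", "3")) // 3
	// Later: PodReschedulingPlugin has pushed a decremented count to the
	// manager, and that stored value takes precedence over the agent's report.
	fmt.Println(chooseRestartTimes("2", "3")) // 2
}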