diff --git a/component/taskd/taskd/go/common/constant/const.go b/component/taskd/taskd/go/common/constant/const.go index 799a81ffe96d6e0a3402a4b5f3609abb5c5aada4..612038ae2595600721faa83fc427eee22f524880 100644 --- a/component/taskd/taskd/go/common/constant/const.go +++ b/component/taskd/taskd/go/common/constant/const.go @@ -164,9 +164,10 @@ const ( // All cluster info type must be defined here const ( - ClusterRole = "Cluster" - ClusterDRank = "ClusterD" - TaskDRank = "TaskD" + ClusterRole = "Cluster" + ClusterDRank = "ClusterD" + TaskDRank = "TaskD" + ControllerRole = "Controller" ) // All cluster command must be defined here @@ -299,3 +300,18 @@ const ( // LocalProxyEnableOn local proxy enable value LocalProxyEnableOn = "on" ) + +const ( + // JobReschedulingPluginName name of job rescheduling plugin + JobReschedulingPluginName = "JobReschedulingPlugin" + // JobReschedulingStreamName name of job rescheduling stream + JobReschedulingStreamName = "JobReschedulingStream" + // SingalKillMaster singal kill master + SingalKillMaster = "killMaster" + // RestartController restart controller + RestartController = "restart_controller" + // DestroyController destroy controller + DestroyController = "destroy_controller" + // Actions actions + Actions = "actions" +) diff --git a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go index 4cb936b17f81f072d48fd21a2a214352ad95aa44..0aa57bc0fabdfd71fc12174233052a68277fc062 100644 --- a/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go +++ b/component/taskd/taskd/go/framework_backend/manager/infrastructure/storage/type.go @@ -42,3 +42,9 @@ type MsgBody struct { Message string `json:"message"` Extension map[string]string `json:"extension"` } + +// AgentReportInfo agent report info +type AgentReportInfo struct { + FaultRanks []int `json:"fault_ranks"` + RestartTime int `json:"restart_time"` +} diff --git a/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go b/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go new file mode 100644 index 0000000000000000000000000000000000000000..68ca75ff0ff7856fa1b660392d8e987eee16b906 --- /dev/null +++ b/component/taskd/taskd/go/framework_backend/manager/plugins/job_rescheduling/job_rescheduling.go @@ -0,0 +1,183 @@ +/* Copyright(C) 2025. Huawei Technologies Co.,Ltd. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package jobrescheduling for taskd manager plugin +package jobrescheduling + +import ( + "ascend-common/common-utils/hwlog" + "taskd/common/constant" + "taskd/framework_backend/manager/infrastructure" + "taskd/framework_backend/manager/infrastructure/storage" + "taskd/toolkit_backend/net/common" +) + +// JobReschedulingPlugin job rescheduling plugin +type JobReschedulingPlugin struct { + pullMsgs []infrastructure.Msg + faultOccur bool + processStatus string + agentStatus map[int]bool + killMaster bool +} + +var ( + agent0ExitMsg = infrastructure.Msg{ + Receiver: []string{common.AgentRole + "0"}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: constant.ExitAgentCode, + }, + } +) + +// NewJobReschedulingPlugin new job rescheduling plugin +func NewJobReschedulingPlugin() infrastructure.ManagerPlugin { + return &JobReschedulingPlugin{ + pullMsgs: make([]infrastructure.Msg, 0), + faultOccur: false, + processStatus: "", + agentStatus: make(map[int]bool), + killMaster: false, + } +} + +// Name name of job rescheduling plugin +func (job *JobReschedulingPlugin) Name() string { + return constant.JobReschedulingPluginName +} + +// Handle handle job rescheduling plugin +func (job *JobReschedulingPlugin) Handle() (infrastructure.HandleResult, error) { + if job.killMaster { + hwlog.RunLog.Info("JobReschedulingPlugin Handle kill master") + job.processStatus = constant.HandleStageProcess + /*job.pullMsgs = append(job.pullMsgs, infrastructure.Msg{ + Receiver: []string{constant.ControllerRole}, + Body: storage.MsgBody{ + MsgType: constant.Action, + Code: 0, + Extension: map[string]string{ + constant.Actions: utils.ObjToString([]string{constant.DestroyController}), + }, + }, + })*/ + job.pullMsgs = append(job.pullMsgs, agent0ExitMsg) + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil + } + if !job.faultOccur { + hwlog.RunLog.Info("JobReschedulingPlugin not fault occur") + return infrastructure.HandleResult{ + Stage: constant.HandleStageFinal, + }, nil + } + + if value, ok := job.agentStatus[0]; ok && value == true { + hwlog.RunLog.Info("JobReschedulingPlugin agent 0 exit") + job.resetPluginInfo() + return infrastructure.HandleResult{ + Stage: constant.HandleStageFinal, + }, nil + } + hwlog.RunLog.Info("JobReschedulingPlugin handle fault") + job.pullMsgs = append(job.pullMsgs, agent0ExitMsg) + job.processStatus = constant.HandleStageProcess + return infrastructure.HandleResult{Stage: constant.HandleStageProcess}, nil +} + +// Handle handle job rescheduling plugin +func (job *JobReschedulingPlugin) PullMsg() ([]infrastructure.Msg, error) { + msgs := job.pullMsgs + job.pullMsgs = make([]infrastructure.Msg, 0) + return msgs, nil +} + +// Predicate predicate job rescheduling plugin +func (job *JobReschedulingPlugin) Predicate(shot storage.SnapShot) (infrastructure.PredicateResult, error) { + if job.processStatus != "" { + hwlog.RunLog.Infof("JobReschedulingPlugin Predicate processStatus:%v", job.processStatus) + job.updatePluginInfo(shot) + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.JobReschedulingStreamName: "", + }}, nil + } + clusterInfo, ok := shot.ClusterInfos.Clusters[constant.ClusterRole] + if ok { + if clusterInfo.Command[constant.SingalKillMaster] != "" { + hwlog.RunLog.Info("JobReschedulingPlugin Predicate kill master") + job.killMaster = true + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.JobReschedulingStreamName: "", + }}, nil + } + } + + for _, agent := range shot.AgentInfos.Agents { + if agent.Status[constant.ReportFaultRank] != "" { + job.faultOccur = true + job.processStatus = constant.HandleStageInit + hwlog.RunLog.Infof("JobReschedulingPlugin candidate token, info:%v", agent.Status[constant.ReportFaultRank]) + return infrastructure.PredicateResult{PluginName: job.Name(), + CandidateStatus: constant.CandidateStatus, + PredicateStream: map[string]string{ + constant.JobReschedulingStreamName: "", + }}, nil + } + } + hwlog.RunLog.Info("JobReschedulingPlugin not fault occur") + return infrastructure.PredicateResult{ + PluginName: job.Name(), CandidateStatus: constant.UnselectStatus, PredicateStream: nil}, nil +} + +// Release release job rescheduling plugin +func (job *JobReschedulingPlugin) Release() error { + return nil +} + +func (job *JobReschedulingPlugin) resetPluginInfo() { + job.pullMsgs = make([]infrastructure.Msg, 0) + job.faultOccur = false + job.processStatus = "" + job.agentStatus = make(map[int]bool) + job.killMaster = false +} + +func (job *JobReschedulingPlugin) updatePluginInfo(shot storage.SnapShot) { + agenInfo, ok := shot.AgentInfos.Agents[common.AgentRole+"0"] + if !ok { + hwlog.RunLog.Errorf("JobReschedulingPlugin updatePluginInfo agent 0 not exist") + job.resetPluginInfo() + return + } + if agenInfo.Status[constant.Exit] != "" { + job.agentStatus[0] = true + } + clusterInfo, ok := shot.ClusterInfos.Clusters[constant.ClusterRole] + if ok { + if clusterInfo.Command[constant.SingalKillMaster] != "" { + job.killMaster = true + } + } + + for _, agent := range shot.AgentInfos.Agents { + if agent.Status[constant.ReportFaultRank] != "" { + job.faultOccur = true + break + } + } +} diff --git a/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go b/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go index 59f48f146e14a0dd0475cfbe1289cf7583577a8c..d490bc67adea7a0276bf915cd91a0ccb24384e00 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/plugin_handler.go @@ -17,11 +17,14 @@ package service import ( "fmt" + "time" "ascend-common/common-utils/hwlog" + "taskd/common/constant" "taskd/framework_backend/manager/infrastructure" "taskd/framework_backend/manager/infrastructure/storage" "taskd/framework_backend/manager/plugins/faultdig" + jobrescheduling "taskd/framework_backend/manager/plugins/job_rescheduling" "taskd/framework_backend/manager/plugins/om" ) @@ -40,6 +43,8 @@ type PluginHandler struct { Plugins map[string]infrastructure.ManagerPlugin } +var timestart = time.Now() + // Init register all plugin func (p *PluginHandler) Init() error { profilingPlugin := faultdig.NewProfilingPlugin() @@ -52,6 +57,12 @@ func (p *PluginHandler) Init() error { hwlog.RunLog.Errorf("register plugin %s failed!", omPlugin.Name()) return fmt.Errorf("register plugin %s failed", omPlugin.Name()) } + jobReschedulingPlugin := jobrescheduling.NewJobReschedulingPlugin() + if err := p.Register(jobReschedulingPlugin.Name(), jobReschedulingPlugin); err != nil { + hwlog.RunLog.Errorf("register plugin %s failed!", jobReschedulingPlugin.Name()) + return fmt.Errorf("register plugin %s failed", jobReschedulingPlugin.Name()) + } + return nil } @@ -97,6 +108,14 @@ func (p *PluginHandler) Handle(pluginName string) (infrastructure.HandleResult, // Predicate execute the predicate function of all registered plugin func (p *PluginHandler) Predicate(snapshot *storage.SnapShot) []infrastructure.PredicateResult { var predicateResults []infrastructure.PredicateResult + currentTime := time.Now() + if currentTime.Unix()-timestart.Unix() > 30 { + command := make(map[string]string) + command[constant.SingalKillMaster] = "1" + snapshot.ClusterInfos.Clusters[constant.ClusterRole] = &storage.ClusterInfo{ + Command: command, + } + } for _, plugin := range p.Plugins { result, err := plugin.Predicate(*snapshot) if err != nil { diff --git a/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go b/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go index 51d7c4725171719a35bf8a87d61809b96e711a57..c649e1347f4aea33181d21fa9054b327e213357e 100644 --- a/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go +++ b/component/taskd/taskd/go/framework_backend/manager/service/stream_handler.go @@ -66,6 +66,13 @@ func (s *StreamHandler) Init() error { OmStream.GetName()) return err } + jobReschedulingStream := infrastructure.NewStream(constant.JobReschedulingStreamName, + map[string]int{constant.JobReschedulingPluginName: 1}) + if err := s.SetStream(jobReschedulingStream); err != nil { + hwlog.RunLog.Errorf("init stream handler failed: set stream %s failed", + jobReschedulingStream.GetName()) + return err + } return nil } diff --git a/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py b/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py index f4b9571aadcfe64dd8960c41a5e675e9182d0a6c..8273742e41b2416ce346d405927465b80e383a6d 100644 --- a/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py +++ b/component/taskd/taskd/python/framework/agent/base_agent/agent_network.py @@ -27,6 +27,7 @@ from taskd.python.utils.log import run_log from taskd.python.cython_api import cython_api from taskd.python.framework.common.type import MsgBody, MessageInfo, Position, DEFAULT_BIZTYPE from taskd.python.toolkit.constants.constants import SEND_RETRY_TIMES +from taskd.python.toolkit.constants import constants class AgentMessageManager(): @@ -78,12 +79,13 @@ class AgentMessageManager(): run_log.info(f"agent register: {msg}") self.send_message(msg) - def send_message(self, message: MessageInfo): + def send_message(self, message: MessageInfo, code: int = 0): """ Send message to taskd manager. """ run_log.debug(f"agent send message: {message}") msg_json = json.dumps(asdict(message)).encode('utf-8') + self.lib.DestroyNetTool.argtypes = [ctypes.c_void_p] send_times = 0 self.lib.SyncSendMessage.argtypes = [ctypes.c_void_p, ctypes.c_char_p] self.lib.SyncSendMessage.restype = ctypes.c_int @@ -94,6 +96,9 @@ class AgentMessageManager(): result = self.lib.SyncSendMessage(self._network_instance, msg_json) if result == 0: run_log.info(f"agent send message success, msg: {message.uuid}") + if code == constants.EXITAGENTCODE: + run_log.info(f"agent send exit message, msg: {message.uuid}") + self.lib.DestroyNetTool(self._network_instance) break run_log.warning(f"agent send message failed, result: {result}") send_times += 1 @@ -106,6 +111,8 @@ class AgentMessageManager(): while True: self.lib.ReceiveMessageC.argtypes = [ctypes.c_void_p] self.lib.ReceiveMessageC.restype = ctypes.c_void_p + self.lib.FreeCMemory.argtypes = [ctypes.c_void_p] + self.lib.FreeCMemory.restype = None msg_ptr = self.lib.ReceiveMessageC(self._network_instance) if msg_ptr is None: continue @@ -117,9 +124,9 @@ class AgentMessageManager(): continue self.msg_queue.put(msg) self.lib.FreeCMemory(msg_ptr) - if msg.msg_type == "exit": - self.lib.DestroyNetwork(self._network_instance) - return + #if msg.msg_type == "exit": + # self.lib.DestroyNetwork(self._network_instance) + # return def get_network_instance(self): """ @@ -205,7 +212,7 @@ def get_message_manager() -> AgentMessageManager: return AgentMessageManager.instance -def network_send_message(msg: MessageInfo): +def network_send_message(msg: MessageInfo, code: int = 0): """ Send message to taskd manager. """ @@ -213,7 +220,7 @@ def network_send_message(msg: MessageInfo): if msg_manager.get_network_instance() is None: run_log.warning("network instance is None") return - msg_manager.send_message(msg) + msg_manager.send_message(msg, code) def get_msg_network_instance(): diff --git a/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py b/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py index dc3c20cce54e7297bba4e50d53a56c177b777d0e..6717053133647ce4d2be44182831380f4cbe2160 100644 --- a/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py +++ b/component/taskd/taskd/python/framework/agent/base_agent/base_agent.py @@ -34,7 +34,7 @@ DEFAULT_DST = { } -REPORT_CODE = 601 +REPORT_CODE = 202 DEFAULT_MSG_TYPE = "DEFAULT" STATUS_MSG_TYPE = "STATUS" @@ -81,10 +81,10 @@ class BaseAgent: except queue.Empty: run_log.debug('msg_queue is empty') return - self.command_map.get(item.MsgType)(item) + self.command_map.get(item.code)(item) def grace_exit(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to grace exit workers') + run_log.info(f'receive {msg.code} command, start to grace exit workers') try: grace_exit_pids(self.pids) except Exception as e: @@ -108,7 +108,7 @@ class BaseAgent: dst=DEFAULT_DST, body=body_json ) - network_send_message(msg_info) + network_send_message(msg_info, code) def check_network(self): time_cost = 0 diff --git a/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py b/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py index 23b94fc8fda57ec2e702ca99d1253a18abbbe6e5..7ed398985b2dd61a4bb5e3f1d2db8dd575ec2f74 100644 --- a/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py +++ b/component/taskd/taskd/python/framework/agent/ms_agent/ms_agent.py @@ -50,7 +50,7 @@ class MsAgent(BaseAgent): self.command_map = { 'START': self.initialize_workers, 'STOP': self.stop_workers, - 'EXIT': self.exit_agent, + constants.EXITAGENTCODE: self.exit_agent, 'RESTART': self.restart_workers, 'GRACE_EXIT': self.grace_exit, } @@ -123,20 +123,20 @@ class MsAgent(BaseAgent): def initialize_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to initialize workers') + run_log.info(f'receive {msg.code} command, start to initialize workers') self._func_map.get('START_ALL_WORKER')() def stop_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to stop workers') + run_log.info(f'receive {msg.code} command, start to stop workers') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) def exit_agent(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to exit agent') + run_log.info(f'receive {msg.code} command, start to exit agent') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) self.send_message_to_manager('STATUS', REPORT_CODE, AgentReportInfo()) exit(1) def restart_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to restart workers') + run_log.info(f'receive {msg.code} command, start to restart workers') self._func_map.get('KILL_WORKER')([constants.KILL_ALL_WORKERS]) self._func_map.get('START_ALL_WORKER')() diff --git a/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py b/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py index 9e3b6ea2758457cf3a851c945be9bc76dc71ed93..9d2305674c644f9643da58f285816f0ca2278843 100644 --- a/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py +++ b/component/taskd/taskd/python/framework/agent/pt_agent/pt_agent.py @@ -21,6 +21,7 @@ from taskd.python.utils.log import run_log from taskd.python.framework.agent.base_agent.agent_network import init_network_client from taskd.python.framework.agent.base_agent.base_agent import BaseAgent, REPORT_CODE from taskd.python.framework.common.type import AgentReportInfo +from taskd.python.toolkit.constants import constants try: from torch.distributed.elastic.agent.server.api import WorkerState, RunResult except ImportError: @@ -43,7 +44,7 @@ class PtAgent(BaseAgent): self.command_map = { 'START': self.initialize_workers, 'STOP': self.stop_workers, - 'EXIT': self.exit_agent, + constants.EXITAGENTCODE: self.exit_agent, 'RESTART': self.restart_workers, 'GRACE_EXIT': self.grace_exit, } @@ -97,24 +98,24 @@ class PtAgent(BaseAgent): return def initialize_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, restart time is {msg.extension},' + run_log.info(f'receive {msg.code} command, restart time is {msg.extension},' f' start to initialize workers') self.pt_instance._remaining_restarts = int(msg.extension) self._func_map.get('START_ALL_WORKER')(self.worker_group) def stop_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to stop workers') + run_log.info(f'receive {msg.code} command, start to stop workers') self._func_map.get('KILL_WORKER')(self.worker_group) self.worker_group.state = WorkerState.STOPPED def exit_agent(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to exit agent') + run_log.info(f'receive {msg.code} command, start to exit agent') self._func_map.get('KILL_WORKER')(self.worker_group) - self.send_message_to_manager('STATUS', REPORT_CODE, AgentReportInfo()) + self.send_message_to_manager('STATUS', constants.EXITAGENTCODE, AgentReportInfo()) exit(1) def restart_workers(self, msg): - run_log.info(f'receive {msg.msg_type} command, start to restart workers, restart time is {msg.extension}') + run_log.info(f'receive {msg.code} command, start to restart workers, restart time is {msg.extension}') self.pt_instance._remaining_restarts = int(msg.extension) self._func_map.get('KILL_WORKER')(self.worker_group) self.worker_group.state = WorkerState.STOPPED diff --git a/component/taskd/taskd/python/toolkit/constants/constants.py b/component/taskd/taskd/python/toolkit/constants/constants.py index 17440fdc919c0df63d7121f1906f8aa7af0712d2..bef5e2d5eae50f3aed79dd47b0bae022e56f8fb8 100644 --- a/component/taskd/taskd/python/toolkit/constants/constants.py +++ b/component/taskd/taskd/python/toolkit/constants/constants.py @@ -107,4 +107,6 @@ STOP_TRAIN_PAUSE = "pause" SWITCH_NIC_DEFAULT_TIMEOUT = 600 SWITCH_NIC_MAX_TIMEOUT = 120 * 60 -HCCL_CONNECT_TIMEOUT = "HCCL_CONNECT_TIMEOUT" \ No newline at end of file +HCCL_CONNECT_TIMEOUT = "HCCL_CONNECT_TIMEOUT" + +EXITAGENTCODE = 203 \ No newline at end of file