diff --git a/init/Cargo.toml b/init/Cargo.toml index 0f3cac884df0c252b174b2d66688a004d9e92e88..f7cb4e736949d6a7daaa471de1428bf83c06c0e5 100644 --- a/init/Cargo.toml +++ b/init/Cargo.toml @@ -1,14 +1,45 @@ [package] name = "init" -version = "0.2.4" edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +version = "0.2.4" [dependencies] -nix = "0.24" -tokio = { version = "1.29.1", features = ["net", "sync", "process", "rt-multi-thread", "macros", "signal", "time"] } -clap = { version = "3.1.8", features = ["derive"] } -once_cell = "1.18.0" -log = { version = "0.4", features = ["std"] } -psutil = { version = "*", default_features = false, features = ["process"]} +kernlog = "0.3.1" +nix = { features = [ + "fs", + "net", + "signal", + "time", + "user", +], default-features = false, version = "0.24" } +serde = { default-features = false, version = "1.0.130" } + +[dependencies.confique] +version = "0.1.3" +features = ["toml"] +optional = false +default-features = false + +[dependencies.libc] +version = "0.2.138" +features = [] +optional = false +default-features = false + +[dependencies.log] +version = "0.4" +features = ["std"] +optional = false +default-features = false + +[dependencies.mio] +version = "0.8.8" +features = ["os-poll", "os-ext"] +optional = false +default-features = false + +[dependencies.psutil] +version = "3.2.2" +features = ["process"] +optional = false +default-features = false diff --git a/init/src/init.conf b/init/src/init.conf new file mode 100644 index 0000000000000000000000000000000000000000..a18bf0d9d066098d1272147ea3db43bf136fdaae --- /dev/null +++ b/init/src/init.conf @@ -0,0 +1,9 @@ + # [init] + #[config(default = 10)] + timecnt = 9 + #[config(default = 90)] + timewait = 100 + #[config(default = "/usr/lib/sysmaster/sysmaster")] + bin = "/bin/ls" + #[config(default = "/run/sysmaster/init.sock")] + socket = "init.sock" diff --git a/init/src/main.rs b/init/src/main.rs index daf8279a63ad5e5575354a559f3495557c567663..c1b2ee6d0aa7200774211913f6047013908ea12f 100644 --- a/init/src/main.rs +++ b/init/src/main.rs @@ -11,105 +11,93 @@ // See the Mulan PSL v2 for more details. //! The init daemon -use clap::Parser; -use log::{Level, LevelFilter, Log}; -use nix::sys::{ - signal::Signal, - wait::{waitid, Id, WaitPidFlag, WaitStatus}, -}; -use once_cell::sync::OnceCell; +use confique::{Config, FileFormat, Partial}; +use mio::unix::SourceFd; +use mio::Events; +use mio::Interest; +use mio::Poll; +use mio::Token; +use nix::sys::signal; +use nix::sys::signal::SaFlags; +use nix::sys::signal::SigAction; +use nix::sys::signal::SigHandler; +use nix::sys::signal::Signal; +use nix::sys::signalfd::SigSet; +use nix::sys::signalfd::SignalFd; +use nix::sys::socket::getsockopt; +use nix::sys::socket::sockopt::PeerCredentials; +use nix::sys::stat::umask; +use nix::sys::stat::Mode; +use nix::sys::time::TimeSpec; +use nix::sys::time::TimeValLike; +use nix::sys::timerfd::ClockId; +use nix::sys::timerfd::Expiration; +use nix::sys::timerfd::TimerFd; +use nix::sys::timerfd::TimerFlags; +use nix::sys::timerfd::TimerSetTimeFlags; +use nix::sys::wait::waitid; +use nix::sys::wait::Id; +use nix::sys::wait::WaitPidFlag; +use nix::sys::wait::WaitStatus; +use nix::unistd; +use nix::unistd::execv; +use nix::unistd::Pid; +#[allow(unused_imports)] +use nix::unistd::Uid; use psutil::process::Process; -use std::{ - fs::{self, File, OpenOptions}, - io::Write, - path::{Path, PathBuf}, - process, - sync::Arc, - time::Duration, -}; -use tokio::{ - net::UnixListener, - process::Command, - signal::unix::{signal, SignalKind}, - sync::RwLock, - time::sleep, -}; - -struct InitLog { - kmsg: std::sync::Mutex, - maxlevel: LevelFilter, -} - -impl InitLog { - pub fn new(filter: LevelFilter) -> InitLog { - InitLog { - kmsg: std::sync::Mutex::new(OpenOptions::new().write(true).open("/dev/kmsg").unwrap()), - maxlevel: filter, - } - } +use std::ffi::CString; +use std::fs; +use std::io; +use std::os::unix::io::AsRawFd; +use std::os::unix::net::UnixListener; +use std::path::Path; +use std::path::PathBuf; +use std::process::Command; +use std::time::Duration; + +const ALLFD_TOKEN: Token = Token(0); +const TIMERFD_TOKEN: Token = Token(1); +const SIGNALFD_TOKEN: Token = Token(2); +const SOCKETFD_TOKEN: Token = Token(3); +#[cfg(not(test))] +const INIT_SOCK: &str = "/run/sysmaster/init.sock"; +#[cfg(test)] +const INIT_SOCK: &str = "init.sock"; +const INIT_CONFIG: &str = "/etc/sysmaster/init.conf"; - pub fn init(filter: LevelFilter) { - let klog = InitLog::new(filter); - _ = log::set_boxed_logger(Box::new(klog)); - log::set_max_level(filter); - } +#[derive(Config, Debug)] +struct InitConfig { + #[config(default = 10)] + timecnt: usize, + #[config(default = 90)] + timewait: i64, + #[config(default = "/usr/lib/sysmaster/sysmaster")] + bin: String, } -impl Log for InitLog { - fn enabled(&self, metadata: &log::Metadata) -> bool { - metadata.level() <= self.maxlevel - } - - fn log(&self, record: &log::Record) { - if record.level() > self.maxlevel { - return; - } - - let level: u8 = match record.level() { - Level::Error => 3, - Level::Warn => 4, - Level::Info => 5, - Level::Debug => 6, - Level::Trace => 7, - }; - - let mut buf = Vec::new(); - writeln!( - buf, - "<{}>{}[{}]: {}", - level, - record.target(), - process::id(), - record.args() +impl InitConfig { + pub fn load(&mut self, file: Option) -> std::io::Result { + type ConfigPartial = ::Partial; + let partial: ConfigPartial = match confique::File::with_format( + file.unwrap_or_else(|| INIT_CONFIG.to_string()), + FileFormat::Toml, ) - .unwrap(); - - if let Ok(mut kmsg) = self.kmsg.lock() { - let _ = kmsg.write(&buf); - let _ = kmsg.flush(); + .load() + { + Err(_) => return Ok(InitConfig::default()), + Ok(v) => v, + }; + match InitConfig::from_partial(partial) { + Ok(v) => Ok(v), + Err(_) => Ok(InitConfig::default()), } } - - fn flush(&self) {} } -#[derive(Parser, Debug)] -#[clap(author, version, about, long_about = None)] -struct InitOptions { - /// Number of monitored and reset instances. - #[clap(long, value_parser, default_value = "10")] - timecnt: usize, - /// Waiting time for monitoring and keeping alive. - #[clap(long, value_parser, default_value = "90")] - timewait: u64, - /// Subcommands for init - #[clap(long, value_parser, default_value = "/usr/lib/sysmaster/sysmaster")] - bin: String, - /// socket path - #[clap(long, value_parser, default_value = "/run/sysmaster/init.sock")] - socket: String, - /// Other options - args: Option, +impl Default for InitConfig { + fn default() -> Self { + InitConfig::from_partial(::Partial::default_values()).unwrap() + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -120,414 +108,586 @@ enum InitState { } struct Runtime { - options: InitOptions, + poll: Poll, + timerfd: TimerFd, + signalfd: SignalFd, + socketfd: UnixListener, + config: InitConfig, state: InitState, // sysmaster pid - pid: i32, - debug: bool, + pid: u32, // sysmaster status online: bool, + deserialize: bool, } -fn runtime() -> &'static RwLock { - static INSTANCE: OnceCell> = OnceCell::new(); - INSTANCE.get_or_init(|| { - #[cfg(not(test))] - let ret = RwLock::new(Runtime::new().unwrap()); - #[cfg(test)] - let ret = RwLock::new(Runtime { - options: InitOptions { - timecnt: 1, - timewait: 0, - bin: "/usr/bin/ls".to_string(), - socket: "/tmp/sysmaster/init.socket".to_string(), - args: None, - }, - state: InitState::Running, - pid: 0, - debug: false, - online: false, - }); - ret - }) -} +impl Runtime { + pub fn new() -> std::io::Result { + // parse arguments, --pid, Invisible to user + let mut pid = 0u32; + let mut args = std::env::args().skip(1); + while let Some(arg) = args.next() { + match arg.as_str() { + "--pid" => { + if let Some(value) = args.next() { + if !value.starts_with('-') { + pid = match value.parse::() { + Ok(v) => v, + Err(e) => panic!("invalid value: {e:?}"), + }; + } else { + panic!("Missing or invalid value for option."); + } + } else { + panic!("Missing value for option --deserialize."); + } + } + _ => { + println!("Unknown argument: {}, ignored!", arg); + } + } + } -macro_rules! runtime_read { - () => { - crate::runtime().read().await - }; -} + // check socket + let sock_path = PathBuf::from(INIT_SOCK); + let sock_parent = sock_path.parent().unwrap(); + if !sock_parent.exists() { + fs::create_dir_all(sock_parent)?; + } + if fs::metadata(INIT_SOCK).is_ok() { + let _ = fs::remove_file(INIT_SOCK); + } + let socketfd = UnixListener::bind(INIT_SOCK)?; + + // add signal + let mut mask = SigSet::empty(); + for sig in [ + Signal::SIGINT, + Signal::SIGTERM, + Signal::SIGCHLD, + Signal::SIGHUP, + ] { + mask.add(sig); + } + mask.thread_set_mask()?; + let signalfd = SignalFd::new(&mask)?; -macro_rules! runtime_write { - () => { - crate::runtime().write().await - }; -} + // set timer + let timerfd = TimerFd::new(ClockId::CLOCK_MONOTONIC, TimerFlags::TFD_NONBLOCK)?; + timerfd.set( + Expiration::OneShot(TimeSpec::from_duration(Duration::from_nanos(1))), + TimerSetTimeFlags::empty(), + )?; + + // parse config + let mut config = InitConfig::default(); + config.load(None)?; -impl Runtime { - // This is an issue with the clippy tool. The new function is used under the not(test) configuration. - #[allow(dead_code)] - pub fn new() -> std::io::Result { Ok(Self { - options: InitOptions::parse(), + poll: Poll::new()?, + timerfd, + signalfd, + socketfd, + config, state: InitState::Init, - pid: 0, - debug: false, + pid, online: false, + deserialize: pid != 0, }) } - pub fn is_running(&self) -> bool { - self.state == InitState::Running - } + pub fn register(&mut self, token: Token) -> std::io::Result<()> { + let binding = self.signalfd.as_raw_fd(); + let mut signal_source = SourceFd(&binding); + let binding = self.timerfd.as_raw_fd(); + let mut time_source = SourceFd(&binding); + let binding = self.socketfd.as_raw_fd(); + let mut unix_source = SourceFd(&binding); + + match token { + SIGNALFD_TOKEN => self.poll.registry().register( + &mut signal_source, + SIGNALFD_TOKEN, + Interest::READABLE, + )?, + TIMERFD_TOKEN => self.poll.registry().register( + &mut time_source, + TIMERFD_TOKEN, + Interest::READABLE, + )?, + SOCKETFD_TOKEN => self.poll.registry().register( + &mut unix_source, + SOCKETFD_TOKEN, + Interest::READABLE, + )?, + _ => { + self.poll.registry().register( + &mut signal_source, + SIGNALFD_TOKEN, + Interest::READABLE, + )?; + self.poll.registry().register( + &mut time_source, + TIMERFD_TOKEN, + Interest::READABLE, + )?; + self.poll.registry().register( + &mut unix_source, + SOCKETFD_TOKEN, + Interest::READABLE, + )?; + } + } - pub fn socket(&self) -> String { - self.options.socket.clone() + Ok(()) } - pub fn bin(&self) -> String { - self.options.bin.clone() - } + pub fn deregister(&mut self, token: Token) -> std::io::Result<()> { + let binding = self.signalfd.as_raw_fd(); + let mut signal_source = SourceFd(&binding); + let binding = self.timerfd.as_raw_fd(); + let mut time_source = SourceFd(&binding); + let binding = self.socketfd.as_raw_fd(); + let mut unix_source = SourceFd(&binding); + + match token { + SIGNALFD_TOKEN => self.poll.registry().deregister(&mut signal_source)?, + TIMERFD_TOKEN => self.poll.registry().deregister(&mut time_source)?, + SOCKETFD_TOKEN => self.poll.registry().deregister(&mut unix_source)?, + _ => { + self.poll.registry().deregister(&mut signal_source)?; + self.poll.registry().deregister(&mut time_source)?; + self.poll.registry().deregister(&mut unix_source)?; + } + } - pub fn pid(&self) -> i32 { - self.pid + Ok(()) } - pub fn set_pid(&mut self, pid: i32) { - self.pid = pid; + fn load_config(&mut self) -> std::io::Result<()> { + match self.config.load(None) { + Ok(_) => Ok(()), + Err(e) => { + log::error!("Failed to load config, error: {}, ignored!", e); + Ok(()) + } + } } - pub fn _state(&self) -> InitState { - self.state.clone() - } + fn reap_zombie(&self) { + // peek signal + let flags = WaitPidFlag::WEXITED | WaitPidFlag::WNOHANG | WaitPidFlag::WNOWAIT; + loop { + let wait_status = match waitid(Id::All, flags) { + Ok(status) => status, + Err(_) => return, + }; + + let si = match wait_status { + WaitStatus::Exited(pid, code) => Some((pid, code, Signal::SIGCHLD)), + WaitStatus::Signaled(pid, signal, _dc) => Some((pid, -1, signal)), + _ => None, // ignore + }; + + // check + let (pid, _, _) = match si { + Some((pid, code, sig)) => (pid, code, sig), + None => { + log::debug!("Ignored child signal: {:?}!", wait_status); + return; + } + }; - pub fn set_state(&mut self, state: InitState) { - self.state = state; - } + if pid.as_raw() <= 0 { + log::debug!("pid:{:?} is invalid! Ignored it.", pid); + return; + } - pub fn debug(&self) -> bool { - true + // pop: recycle the zombie + if let Err(e) = waitid(Id::Pid(pid), WaitPidFlag::WEXITED) { + log::error!("Error when reap the zombie({:?}), ignoring: {:?}!", pid, e); + } else { + log::debug!("reap the zombie: pid:{:?}.", pid); + } + } } - pub fn set_debug(&mut self, debug: bool) { - self.debug = debug; + pub fn handle_signal(&mut self) -> std::io::Result<()> { + let sig = match self.signalfd.read_signal()? { + Some(s) => s, + None => return Ok(()), + }; + match Signal::try_from(sig.ssi_signo as i32)? { + Signal::SIGHUP => self.reload()?, + Signal::SIGINT => { + println!("Received SIGINT for pid({:?})", sig.ssi_pid); + self.exit(1); + } + Signal::SIGKILL => self.kill_manager(), + Signal::SIGTERM => self.state = InitState::Reexec, + Signal::SIGCHLD => self.reap_zombie(), + _ => { + println!( + "Received signo {:?} for pid({:?}), ignored!", + sig.ssi_signo, sig.ssi_pid + ); + } + }; + Ok(()) } - pub fn online(&self) -> bool { - self.online - } + pub fn handle_timer(&mut self) -> std::io::Result<()> { + if self.config.timecnt == 0 { + log::error!( + "Restarted {} times, {} seconds each time, will not continue.", + self.config.timecnt, + self.config.timewait + ); + self.deregister(TIMERFD_TOKEN)?; + self.deregister(SOCKETFD_TOKEN)?; + return Ok(()); + } - pub fn set_online(&mut self, online: bool) { - self.online = online; + if self.online { + log::debug!("Set online to false!"); + self.online = false; + } else { + self.start_bin(); + self.config.timecnt -= 1; + log::error!( + "Reamining {} times to start {}!", + self.config.timecnt, + self.config.bin + ); + } + self.timerfd.set( + Expiration::OneShot(TimeSpec::seconds(self.config.timewait)), + TimerSetTimeFlags::empty(), + )?; + Ok(()) } - pub fn timecnt(&self) -> usize { - self.options.timecnt - } + pub fn handle_socket(&mut self) -> std::io::Result<()> { + let (stream, _) = match self.socketfd.accept() { + Ok((connection, address)) => (connection, address), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + // If we get a `WouldBlock` error we know our + // listener has no more incoming connections queued, + // so we can return to polling and wait for some + // more. + return Ok(()); + } + Err(e) => { + // If it was any other kind of error, something went + // wrong and we terminate with an error. + log::error!("Error accepting connection: {}!", e); + return Err(e); + } + }; - pub fn timewait(&self) -> u64 { - self.options.timewait + let credentials = getsockopt(stream.as_raw_fd(), PeerCredentials)?; + let pid = credentials.pid(); + if pid == self.pid.try_into().unwrap() { + // If the incoming PID is not the monitored sysmaster, + // do not refresh the status. + self.online = true; + log::debug!("Set online to true!"); + log::debug!("Keepalive: receive a heartbeat from pid({})!", pid); + } + Ok(()) } -} -async fn bootup() { - log::info!("bootup started."); - log::info!("bootup completed."); -} - -fn rape_zombie() { - // peek signal - let flags = WaitPidFlag::WEXITED | WaitPidFlag::WNOHANG | WaitPidFlag::WNOWAIT; - loop { - let wait_status = match waitid(Id::All, flags) { - Ok(status) => status, - Err(_) => return, - }; - - let si = match wait_status { - WaitStatus::Exited(pid, code) => Some((pid, code, Signal::SIGCHLD)), - WaitStatus::Signaled(pid, signal, _dc) => Some((pid, -1, signal)), - _ => None, // ignore - }; + fn start_bin(&mut self) { + // check sysmaster status, if it is running then sigterm it + if self.pid != 0 { + if self.deserialize { + return; + } + if let Ok(process) = Process::new(self.pid) { + if process.is_running() { + _ = process.terminate(); + log::info!( + "Timeout: send SIGTERM to {} ({})!", + self.config.bin, + self.pid + ); + return; + } + } + } - // check - let (pid, _, _) = match si { - Some((pid, code, sig)) => (pid, code, sig), + // else start the binary + let mut parts = self.config.bin.split_whitespace(); + let command = match parts.next() { + Some(c) => c, None => { - log::debug!("Ignored child signal: {:?}!", wait_status); + log::error!("Wrong command: {:?}!", self.config.bin); return; } }; - - if pid.as_raw() <= 0 { - log::debug!("pid:{:?} is invalid! Ignored it.", pid); - return; + let args: Vec<&str> = parts.collect(); + if !Path::new(command).exists() { + log::error!("{:?} does not exest!", command); } - // pop: recycle the zombie - if let Err(e) = waitid(Id::Pid(pid), WaitPidFlag::WEXITED) { - log::error!("Error when rape the zombie({:?}), ignoring: {:?}!", pid, e); - } else { - log::debug!("rape the zombie: pid:{:?}.", pid); - } - } -} + let child_process = match Command::new(command).args(&args).spawn() { + Ok(child) => child, + Err(_) => { + log::error!("Failed to spawn process: {:?}.", command); + return; + } + }; -async fn signald() { - let mut sigterm = signal(SignalKind::terminate()).unwrap(); - let mut sigint = signal(SignalKind::interrupt()).unwrap(); - let mut sigchld = signal(SignalKind::child()).unwrap(); - let mut sighup = signal(SignalKind::hangup()).unwrap(); - // let mut sigfault = signal(SignalKind::from_raw(9)).unwrap(); + self.pid = child_process.id(); + log::info!("Startup: start {}({}))!", self.config.bin, self.pid); + } - loop { - if !runtime_read!().is_running() { - break; - } + pub fn runloop(&mut self) -> std::io::Result<()> { + self.register(ALLFD_TOKEN)?; + let mut events = Events::with_capacity(16); - tokio::select! { - _ = sigterm.recv() => { - runtime_write!().set_state(InitState::Reexec); + // event loop. + loop { + if !self.is_running() { + self.deregister(ALLFD_TOKEN)?; + break; } - signal = sigint.recv() => { - if signal.is_some() { - println!("Received signal: SIGINT!"); - let debug = !runtime_read!().debug(); - runtime_write!().set_debug(debug); - std::process::exit(-1); + + self.poll.poll(&mut events, None)?; + + // Process each event. + for event in events.iter() { + match event.token() { + SIGNALFD_TOKEN => self.handle_signal()?, + TIMERFD_TOKEN => self.handle_timer()?, + SOCKETFD_TOKEN => self.handle_socket()?, + _ => unreachable!(), } } - _ = sigchld.recv() => { - rape_zombie(); - } - _ = sighup.recv() => { - reload(); - } - }; + + #[cfg(test)] + self.set_state(InitState::Init); + } + + Ok(()) } -} -async fn keepalive() { - let rt_socket = runtime_read!().socket(); - let rt_bin = runtime_read!().bin(); + pub fn is_running(&self) -> bool { + self.state == InitState::Running + } - let sock_path = PathBuf::from(rt_socket.clone()); - let path = sock_path.parent().unwrap(); - if !path.exists() { - if let Err(e) = fs::create_dir_all(path) { - log::error!("Failed to create directory {path:?}: {e}!"); - return; - } + pub fn set_state(&mut self, state: InitState) { + self.state = state; } - if fs::metadata(sock_path.clone()).is_ok() { - let _ = fs::remove_file(sock_path); + + fn reload(&mut self) -> std::io::Result<()> { + log::info!("Reloading init configuration"); + self.load_config()?; + Ok(()) + } + + pub fn is_reexec(&self) -> bool { + self.state == InitState::Reexec } - let listener = Arc::new(UnixListener::bind(rt_socket.clone()).unwrap()); - loop { - if !runtime_read!().is_running() { - break; + pub fn reexec(&mut self) { + let exe = match std::env::current_exe().unwrap().file_name() { + Some(v) => v.to_string_lossy().to_string(), + None => "".to_string(), + }; + let pid = self.pid.to_string(); + + for arg0 in [&exe, "/init", "/sbin/init"] { + let argv = vec![ + <&str>::clone(&arg0).to_string(), + "--pid".to_string(), + pid.clone(), + ]; + let cstr_argv = argv + .iter() + .map(|str| std::ffi::CString::new(&**str).unwrap()) + .collect::>(); + log::info!("Reexecuting init: {:?}!", argv.as_slice()); + if let Err(e) = execv(&CString::new(<&str>::clone(&arg0)).unwrap(), &cstr_argv) { + log::error!("Execv {arg0:?} {argv:?} Failed: {e}!"); + }; } + } - match listener.accept().await { - Ok((stream, _)) => { - // get pid of connection socket - let cred = stream.peer_cred().unwrap(); - let pid = cred.pid().unwrap(); - let rt_pid = runtime_read!().pid(); - if pid == rt_pid { - // If the incoming PID is not the monitored sysmaster, - // do not refresh the status. - runtime_write!().set_online(true); - log::debug!( - "Keepalive: receive a heartbeat from {} ({})!", - rt_bin, - pid - ); - } + fn kill_manager(&mut self) { + if let Ok(process) = Process::new(self.pid) { + if process.is_running() { + _ = process.terminate(); + log::info!( + "Received SIGKILL : send SIGTERM to {} ({})!", + self.config.bin, + self.pid + ); } - Err(e) => log::error!("Error accepting connection: {}!", e), } } - let _ = fs::remove_file(rt_socket); + + fn exit(&self, i: i32) { + std::process::exit(i); + } } -async fn watchdog() { - let mut rt_timecnt = runtime_read!().timecnt(); - let rt_timewait = runtime_read!().timewait(); - let rt_bin = runtime_read!().bin(); +fn prepare_init() { + // version + let version = env!("CARGO_PKG_VERSION"); + log::info!("sysMaster init version: {}", version); + let args: Vec = std::env::args().collect(); + if args.contains(&String::from("--version")) || args.contains(&String::from("-V")) { + println!("sysMaster init version: {}!", version); + std::process::exit(0); + } - loop { - if !runtime_read!().is_running() || rt_timecnt == 0 { - break; - } + // common umask + let mode = Mode::from_bits_truncate(0o77); + umask(umask(mode) | Mode::from_bits_truncate(0o22)); - if !(runtime_read!().online()) { - // False for two consecutive times, - // indicating an abnormal state of sysmaster. - check_bin().await; - rt_timecnt -= 1; - #[cfg(test)] - runtime_write!().set_online(true); - } else { - runtime_write!().set_online(false); - } + // euid check + #[cfg(not(test))] + if unistd::geteuid() != Uid::from_raw(0) { + log::error!("init: must be superuser."); + std::process::exit(1); + } - sleep(Duration::from_secs(rt_timewait)).await; + if unistd::getpid() != Pid::from_raw(1) { + log::info!("init: running in test mode."); } - log::info!( - "Restarted {} {} times, {} seconds each time, will not continue.", - rt_bin, - rt_timecnt, - rt_timewait - ); } -async fn check_bin() { - // check sysmaster status, if it is running then sigterm it - let rt_pid = runtime_read!().pid(); - let rt_bin = runtime_read!().bin(); - if rt_pid != 0 { - if let Ok(process) = Process::new(rt_pid as u32) { - if process.is_running() { - _ = process.terminate(); - log::info!("Timeout: send SIGTERM to {} ({})!", rt_bin, rt_pid); - return; - } +fn reset_all_signal_handlers() { + // Create an empty signal set + let mut sigset = SigSet::empty(); + + // Add all signals to the signal set + for sig in signal::Signal::iterator() { + if sig == signal::Signal::SIGKILL || sig == signal::Signal::SIGSTOP { + continue; // Do not allow ignoring SIGKILL and SIGSTOP signals } + sigset.add(sig); } - // else start the binary - let mut parts = rt_bin.split_whitespace(); - let command = match parts.next() { - Some(c) => c, - None => { - log::error!("Wrong command: {:?}!", rt_bin); - return; + // Set the signal handler to be ignored + let sig_action = SigAction::new(SigHandler::SigIgn, SaFlags::SA_RESTART, SigSet::empty()); + for sig in sigset.iter() { + unsafe { + signal::sigaction(sig, &sig_action).expect("Failed to set signal handler"); } - }; - let args: Vec<&str> = parts.collect(); - if !Path::new(command).exists() { - log::error!("{:?} does not exest!", command); } +} - let child_process = match Command::new(command).args(&args).spawn() { - Ok(child) => child, - Err(_) => { - log::error!("Failed to spawn process: {:?}.", command); - return; - } - }; +extern "C" fn crash_handler(_signal: i32) { + log::error!("crash_handler"); +} - if let Some(pid) = child_process.id() { - runtime_write!().set_pid(pid as i32); - log::info!("Startup: start {}({}))!", rt_bin, pid); +fn install_crash_handler() { + let signals_crash_handler = [ + signal::SIGSEGV, + signal::SIGILL, + signal::SIGFPE, + signal::SIGBUS, + signal::SIGABRT, + ]; + let sig_action = SigAction::new( + SigHandler::Handler(crash_handler), + SaFlags::SA_SIGINFO | SaFlags::SA_NODEFER, + SigSet::empty(), + ); + + for sig in signals_crash_handler { + unsafe { + signal::sigaction(sig, &sig_action).expect("Failed to set crash signal handler"); + } } } -fn prepare_init() {} -fn reset_all_signal_handlers() {} -fn install_crash_handler() {} -fn reexec() {} -fn reload() {} fn shutdown_init() { nix::unistd::sync(); + log::info!("shutdowning init"); } -#[tokio::main] -async fn main() -> std::io::Result<()> { - InitLog::init(LevelFilter::Info); +fn main() -> std::io::Result<()> { + match kernlog::init() { + Ok(_) => (), + Err(e) => panic!("Unsupported when cannot log into /dev/kmsg : {e:?}"), + }; + prepare_init(); + reset_all_signal_handlers(); install_crash_handler(); - runtime_write!().set_state(InitState::Running); - let bootup_handle = tokio::spawn(async move { bootup().await }); - let signald_handle = tokio::spawn(async move { signald().await }); - let keepalive_handle = tokio::spawn(async move { keepalive().await }); - let watchdog_handle = tokio::spawn(async move { watchdog().await }); - - _ = tokio::join!( - bootup_handle, - signald_handle, - keepalive_handle, - watchdog_handle - ); - - log::info!("all tasks completed!"); + let mut rt = Runtime::new()?; + rt.set_state(InitState::Running); - reexec(); + rt.runloop()?; + if rt.is_reexec() { + rt.reexec(); + } shutdown_init(); Ok(()) } #[cfg(test)] -mod test { - use std::{os::unix::net::UnixStream, time::Duration}; - - use nix::{ - sys::signal::{kill, Signal}, - unistd::Pid, - }; - use tokio::time::sleep; +mod tests { + use super::*; + + #[test] + fn test_runtime() -> std::io::Result<()> { + let mut rt = Runtime::new()?; + rt.set_state(InitState::Running); + rt.config.timewait = 0; + rt.runloop()?; + assert_ne!(rt.timerfd.as_raw_fd(), 0); + assert_ne!(rt.signalfd.as_raw_fd(), 0); + assert_ne!(rt.socketfd.as_raw_fd(), 0); + Ok(()) + } - use crate::{keepalive, watchdog}; - use crate::{signald, InitState}; + #[test] + fn test_default_config() { + let config = InitConfig::default(); + assert_eq!(config.timecnt, 10); + assert_eq!(config.timewait, 90); + assert_eq!(config.bin, "/usr/lib/sysmaster/sysmaster"); + } - #[tokio::test] - async fn test_watchdog() { - runtime_write!().set_state(InitState::Running); - runtime_write!().set_online(false); - watchdog().await; - assert!(runtime_read!().online()); + #[test] + fn test_load_config() { + let mut config = InitConfig::default(); + let config = config.load(Some("/path/to/init.conf".to_string())).unwrap(); + assert_eq!(config.timecnt, 10); + assert_eq!(config.timewait, 90); + assert_eq!(config.bin, "/usr/lib/sysmaster/sysmaster"); } - #[tokio::test] - // In tests, signal handling is controlled and cannot be easily tested. - async fn test_signald() { - runtime_write!().set_state(InitState::Running); - let handle1 = tokio::spawn(async move { - loop { - sleep(Duration::from_millis(100)).await; - - match kill(Pid::this(), Signal::SIGTERM) { - Ok(_) => log::info!("Signal sent successfully."), - Err(err) => log::error!("Failed to send signal: {}", err), - } - } - }); - let handle2 = tokio::spawn(async move { signald().await }); - tokio::select! { - _ = handle1 => { } - _ = handle2 => { } - }; - assert_eq!(runtime_read!().state, InitState::Reexec); + #[test] + fn test_load_failed_config() { + let mut config = InitConfig::default(); + let config = config.load(Some("src/init.conf".to_string())).unwrap(); + assert_eq!(config.timecnt, 9); + assert_eq!(config.timewait, 100); + assert_eq!(config.bin, "/bin/ls"); } - #[tokio::test] - async fn test_keepalive() { - runtime_write!().set_state(InitState::Running); - let client = tokio::task::spawn(async move { - #[allow(unused_assignments)] - let mut connected = false; - loop { - sleep(Duration::from_millis(100)).await; - match UnixStream::connect(runtime_read!().socket()) { - Ok(_) => { - connected = true; - break; - } - Err(_) => continue, - }; - } - assert!(connected); - }); + #[test] + fn test_main() { + prepare_init(); + + reset_all_signal_handlers(); + install_crash_handler(); + let mut rt = Runtime::new().unwrap(); + rt.set_state(InitState::Running); + + rt.runloop().unwrap(); - let server = tokio::spawn(async move { keepalive().await }); - tokio::select! { - _ = server => { } - _ = client => { } + if rt.is_reexec() { + rt.reexec(); } + shutdown_init(); } }