From f2b0adb94712114a1140821ef8ff5fe9c5a8c85b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=A0=E4=BA=BA?= Date: Wed, 10 Jul 2024 18:35:53 +0800 Subject: [PATCH] fix: update options and ctx stop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 林玠人 --- src/app/server/cmd/commands/root.go | 51 +++++---- src/app/server/cmd/options/options.go | 7 -- src/app/server/config/config.go | 106 +++++++++--------- .../server/network/controller/fileservice.go | 8 +- src/app/server/network/httpserver.go | 6 +- src/app/server/network/jwt/jwt.go | 6 +- src/app/server/network/session.go | 4 +- src/app/server/network/socketserver.go | 4 +- .../network/websocket/client_manager.go | 6 +- src/app/server/service/auth/casbin.go | 4 +- src/app/server/service/auth/casbin_test.go | 23 ++-- src/app/server/service/plugin/plugin.go | 2 +- src/dbmanager/db.go | 9 +- src/dbmanager/redismanager/redismanager.go | 5 +- 14 files changed, 122 insertions(+), 119 deletions(-) diff --git a/src/app/server/cmd/commands/root.go b/src/app/server/cmd/commands/root.go index 293af5aa..ac57ce5d 100644 --- a/src/app/server/cmd/commands/root.go +++ b/src/app/server/cmd/commands/root.go @@ -31,6 +31,7 @@ func NewServerCommand() *cobra.Command { conf, err := options.TryLoadFromDisk() if err == nil { s.ServerConfig = conf + config.OptionsConfig = conf klog.InfoS("TryLoadFromDisk pilotgo config !", "HttpServer", *s.ServerConfig.HttpServer) klog.InfoS("TryLoadFromDisk pilotgo config !", "SocketServer", *s.ServerConfig.SocketServer) klog.InfoS("TryLoadFromDisk pilotgo config !", "JWT", *s.ServerConfig.JWT) @@ -65,58 +66,62 @@ func NewServerCommand() *cobra.Command { return cmd } -func run(_ *options.ServerOptions, ctx context.Context, cmd *cobra.Command) error { +func run(opts *options.ServerOptions, ctx context.Context, cmd *cobra.Command) error { if atomic.LoadInt64(&conut) > 0 { return nil } atomic.AddInt64(&conut, 1) - config_file, err := cmd.Flags().GetString(flagconfig) - if err != nil { - return errors.Wrapf(err, "error accessing flag %s for command %s", flagconfig, cmd.Name()) - } - err = config.Init(config_file) - if err != nil { - fmt.Println("failed to load configure, exit..", err) - return err - } - if config.Config().Storage.Path == "" { + // config_file, err := cmd.Flags().GetString(flagconfig) + // if err != nil { + // return errors.Wrapf(err, "error accessing flag %s for command %s", flagconfig, cmd.Name()) + // } + // err = config.Init(config_file) + // if err != nil { + // fmt.Println("failed to load configure, exit..", err) + // return err + // } + config := opts.ServerConfig + if config.Storage.Path == "" { fmt.Println("Please set the path for file service storage in yaml") return errors.New("storage path is nil") } - if err := logger.Init(&config.Config().Logopts); err != nil { + if err := logger.Init(config.Logopts); err != nil { fmt.Printf("logger init failed, please check the config file: %s", err) return err } logger.Info("Thanks to choose PilotGo!") // redis db初始化 - if err := dbmanager.RedisdbInit(&config.Config().RedisDBinfo, ctx.Done()); err != nil { + if err := dbmanager.RedisdbInit(config.RedisDBinfo, ctx.Done()); err != nil { + if err == context.Canceled { + return nil + } logger.Error("redis db init failed, please check again: %s", err) return err } // mysql db初始化 - if err := dbmanager.MysqldbInit(&config.Config().MysqlDBinfo); err != nil { + if err := dbmanager.MysqldbInit(config.MysqlDBinfo); err != nil { logger.Error("mysql db init failed, please check again: %s", err) return err } // 启动agent socket server - if err := network.SocketServerInit(&config.Config().SocketServer, ctx.Done()); err != nil { + if err := network.SocketServerInit(config.SocketServer, ctx.Done()); err != nil { logger.Error("socket server init failed, error:%v", err) return err } //此处启动前端及REST http server - err = network.HttpServerInit(&config.Config().HttpServer, ctx.Done()) + err := network.HttpServerInit(config.HttpServer, ctx.Done()) if err != nil { logger.Error("HttpServerInit socket server init failed, error:%v", err) return err } // 启动所有功能模块服务 - if err := startServices(ctx.Done()); err != nil { + if err := startServices(config.MysqlDBinfo, ctx.Done()); err != nil { logger.Error("start services error: %s", err) return err } @@ -130,9 +135,9 @@ func run(_ *options.ServerOptions, ctx context.Context, cmd *cobra.Command) erro return nil } -func startServices(stopCh <-chan struct{}) error { +func startServices(mysqlInfo *options.MysqlDBInfo, stopCh <-chan struct{}) error { // 鉴权模块初始化 - auth.Casbin(&config.Config().MysqlDBinfo) + auth.Casbin(mysqlInfo) // 初始化plugin服务 plugin.Init(stopCh) @@ -164,6 +169,14 @@ func Run(s *options.ServerOptions, ctx context.Context, cmd *cobra.Command, conf klog.Warningln("config is change") cancelFunc() s.ServerConfig = &cfg + config.OptionsConfig = &cfg + klog.InfoS("watchConfig pilotgo config receive!", "HttpServer", cfg.HttpServer) + klog.InfoS("watchConfig pilotgo config receive!", "SocketServer", cfg.SocketServer) + klog.InfoS("watchConfig pilotgo config receive!", "JWT", cfg.JWT) + klog.InfoS("watchConfig pilotgo config receive!", "Logopts", cfg.Logopts) + klog.InfoS("watchConfig pilotgo config receive!", "RedisDBinfo", cfg.RedisDBinfo) + klog.InfoS("watchConfig pilotgo config receive!", "MysqlDBinfo", cfg.MysqlDBinfo) + klog.InfoS("watchConfig pilotgo config receive!", "Storage", cfg.Storage) cctx, cancelFunc = context.WithCancel(context.TODO()) go func() { if err := runer(s, cctx, cmd); err != nil { diff --git a/src/app/server/cmd/options/options.go b/src/app/server/cmd/options/options.go index f7e8ede4..92ba78cb 100644 --- a/src/app/server/cmd/options/options.go +++ b/src/app/server/cmd/options/options.go @@ -132,13 +132,6 @@ func (c *pilotgoConfig) watchConfig() <-chan ServerConfig { if err := viper.Unmarshal(cfg); err != nil { klog.Errorf("config reload error", err) } else { - klog.InfoS("watchConfig pilotgo config !", "HttpServer", cfg.HttpServer) - klog.InfoS("watchConfig pilotgo config !", "SocketServer", cfg.SocketServer) - klog.InfoS("watchConfig pilotgo config !", "JWT", cfg.JWT) - klog.InfoS("watchConfig pilotgo config !", "Logopts", cfg.Logopts) - klog.InfoS("watchConfig pilotgo config !", "RedisDBinfo", cfg.RedisDBinfo) - klog.InfoS("watchConfig pilotgo config !", "MysqlDBinfo", cfg.MysqlDBinfo) - klog.InfoS("watchConfig pilotgo config !", "Storage", cfg.Storage) if in.Op&fsnotify.Write != 0 && len(viper.AllKeys()) > 0 { c.cfgChangeCh <- *cfg } diff --git a/src/app/server/config/config.go b/src/app/server/config/config.go index 9d70f2f3..373c234c 100644 --- a/src/app/server/config/config.go +++ b/src/app/server/config/config.go @@ -14,68 +14,64 @@ ******************************************************************************/ package config -import ( - "time" +import "gitee.com/openeuler/PilotGo/app/server/cmd/options" - "gitee.com/openeuler/PilotGo/sdk/logger" - "gitee.com/openeuler/PilotGo/sdk/utils/config" -) +// type HttpServer struct { +// Addr string `yaml:"addr"` +// SessionCount int `yaml:"session_count"` +// SessionMaxAge int `yaml:"session_max_age"` +// Debug bool `yaml:"debug"` +// UseHttps bool `yaml:"use_https"` +// CertFile string `yaml:"cert_file"` +// KeyFile string `yaml:"key_file"` +// } -type HttpServer struct { - Addr string `yaml:"addr"` - SessionCount int `yaml:"session_count"` - SessionMaxAge int `yaml:"session_max_age"` - Debug bool `yaml:"debug"` - UseHttps bool `yaml:"use_https"` - CertFile string `yaml:"cert_file"` - KeyFile string `yaml:"key_file"` -} +// type SocketServer struct { +// Addr string `yaml:"addr"` +// } -type SocketServer struct { - Addr string `yaml:"addr"` -} +// type JWTConfig struct { +// SecretKey string `yaml:"secret_key"` +// } -type JWTConfig struct { - SecretKey string `yaml:"secret_key"` -} +// type MysqlDBInfo struct { +// HostName string `yaml:"host_name"` +// UserName string `yaml:"user_name"` +// Password string `yaml:"password"` +// DataBase string `yaml:"data_base"` +// Port int `yaml:"port"` +// } -type MysqlDBInfo struct { - HostName string `yaml:"host_name"` - UserName string `yaml:"user_name"` - Password string `yaml:"password"` - DataBase string `yaml:"data_base"` - Port int `yaml:"port"` -} +// type RedisDBInfo struct { +// RedisConn string `yaml:"redis_conn"` +// UseTLS bool `yaml:"use_tls"` +// RedisPwd string `yaml:"redis_pwd"` +// DefaultDB int `yaml:"defaultDB"` +// DialTimeout time.Duration `yaml:"dialTimeout"` +// EnableRedis bool `yaml:"enableRedis"` +// } -type RedisDBInfo struct { - RedisConn string `yaml:"redis_conn"` - UseTLS bool `yaml:"use_tls"` - RedisPwd string `yaml:"redis_pwd"` - DefaultDB int `yaml:"defaultDB"` - DialTimeout time.Duration `yaml:"dialTimeout"` - EnableRedis bool `yaml:"enableRedis"` -} +// type Storage struct { +// Path string `yaml:"path"` +// } -type Storage struct { - Path string `yaml:"path"` -} +// type ServerConfig struct { +// HttpServer HttpServer `yaml:"http_server"` +// SocketServer SocketServer `yaml:"socket_server"` +// JWT JWTConfig `api:"jwt"` +// Logopts logger.LogOpts `yaml:"log"` +// MysqlDBinfo MysqlDBInfo `yaml:"mysql"` +// RedisDBinfo RedisDBInfo `yaml:"redis"` +// Storage Storage `yaml:"storage"` +// } -type ServerConfig struct { - HttpServer HttpServer `yaml:"http_server"` - SocketServer SocketServer `yaml:"socket_server"` - JWT JWTConfig `api:"jwt"` - Logopts logger.LogOpts `yaml:"log"` - MysqlDBinfo MysqlDBInfo `yaml:"mysql"` - RedisDBinfo RedisDBInfo `yaml:"redis"` - Storage Storage `yaml:"storage"` -} +// var global_config ServerConfig -var global_config ServerConfig +// func Init(path string) error { +// return config.Load(path, &global_config) +// } -func Init(path string) error { - return config.Load(path, &global_config) -} - -func Config() *ServerConfig { - return &global_config -} +// func Config() *ServerConfig { +// return &global_config +// } +var OptionsConfig *options.ServerConfig diff --git a/src/app/server/network/controller/fileservice.go b/src/app/server/network/controller/fileservice.go index e97a9790..70137485 100644 --- a/src/app/server/network/controller/fileservice.go +++ b/src/app/server/network/controller/fileservice.go @@ -33,8 +33,8 @@ func Upload(c *gin.Context) { } filename := parsedURL.Query().Get("filename") - uploadPath := c.DefaultQuery("path", config.Config().Storage.Path) // 获取上传文件的保存路径,可以通过path设置上传路径 - if err := os.MkdirAll(uploadPath, os.ModePerm); err != nil { // 确保保存路径存在,如果不存在则创建 + uploadPath := c.DefaultQuery("path", config.OptionsConfig.Storage.Path) // 获取上传文件的保存路径,可以通过path设置上传路径 + if err := os.MkdirAll(uploadPath, os.ModePerm); err != nil { // 确保保存路径存在,如果不存在则创建 response.Fail(c, gin.H{"error": err.Error()}, "创建保存路径失败") return } @@ -62,7 +62,7 @@ func Upload(c *gin.Context) { } defer file.Close() - uploadPath := c.DefaultQuery("path", config.Config().Storage.Path) // 获取上传文件的保存路径,可以通过path设置上传路径 + uploadPath := c.DefaultQuery("path", config.OptionsConfig.Storage.Path) // 获取上传文件的保存路径,可以通过path设置上传路径 if err := os.MkdirAll(uploadPath, os.ModePerm); err != nil { // 确保保存路径存在,如果不存在则创建 response.Fail(c, gin.H{"error": err.Error()}, "保存路径创建失败") @@ -91,7 +91,7 @@ func Download(c *gin.Context) { filename := c.Param("filename") // 获取下载文件的路径,可以通过path设置 - downloadPath := c.DefaultQuery("path", config.Config().Storage.Path) + downloadPath := c.DefaultQuery("path", config.OptionsConfig.Storage.Path) // 构建完整的文件路径 filePath := filepath.Join(downloadPath, filename) diff --git a/src/app/server/network/httpserver.go b/src/app/server/network/httpserver.go index 9ff153d1..9d333db5 100644 --- a/src/app/server/network/httpserver.go +++ b/src/app/server/network/httpserver.go @@ -19,7 +19,7 @@ import ( "net/http" "strings" - sconfig "gitee.com/openeuler/PilotGo/app/server/config" + "gitee.com/openeuler/PilotGo/app/server/cmd/options" "gitee.com/openeuler/PilotGo/app/server/network/controller" "gitee.com/openeuler/PilotGo/app/server/network/controller/agentcontroller" "gitee.com/openeuler/PilotGo/app/server/network/controller/pluginapi" @@ -31,7 +31,7 @@ import ( "k8s.io/klog/v2" ) -func HttpServerInit(conf *sconfig.HttpServer, stopCh <-chan struct{}) error { +func HttpServerInit(conf *options.HttpServer, stopCh <-chan struct{}) error { if err := SessionManagerInit(conf); err != nil { return err } @@ -40,7 +40,7 @@ func HttpServerInit(conf *sconfig.HttpServer, stopCh <-chan struct{}) error { r := setupRouter() // 启动websocket服务 - go websocket.CliManager.Start() + go websocket.CliManager.Start(stopCh) shutdownCtx, cancel := context.WithCancel(context.Background()) defer cancel() srv := &http.Server{ diff --git a/src/app/server/network/jwt/jwt.go b/src/app/server/network/jwt/jwt.go index fe68ea95..24f82d39 100644 --- a/src/app/server/network/jwt/jwt.go +++ b/src/app/server/network/jwt/jwt.go @@ -49,7 +49,7 @@ func GenerateUserToken(user userservice.ReturnUser) (string, error) { }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(config.Config().JWT.SecretKey)) + tokenString, err := token.SignedString([]byte(config.OptionsConfig.JWT.SecretKey)) if err != nil { return "", err } @@ -105,7 +105,7 @@ func GeneratePluginToken(name, uuid string) (string, error) { }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - tokenString, err := token.SignedString([]byte(config.Config().JWT.SecretKey)) + tokenString, err := token.SignedString([]byte(config.OptionsConfig.JWT.SecretKey)) if err != nil { return "", err } @@ -131,7 +131,7 @@ func ParsePluginClaims(c *gin.Context) (*PluginClaims, error) { func parseToken(tokenString string, clames jwt.Claims) (*jwt.Token, error) { token, err := jwt.ParseWithClaims(tokenString, clames, func(token *jwt.Token) (i interface{}, err error) { - return []byte(config.Config().JWT.SecretKey), nil + return []byte(config.OptionsConfig.JWT.SecretKey), nil }) return token, err } diff --git a/src/app/server/network/session.go b/src/app/server/network/session.go index 6d201b9a..a2446e3c 100644 --- a/src/app/server/network/session.go +++ b/src/app/server/network/session.go @@ -5,7 +5,7 @@ import ( "sync" "time" - sconfig "gitee.com/openeuler/PilotGo/app/server/config" + "gitee.com/openeuler/PilotGo/app/server/cmd/options" "gitee.com/openeuler/PilotGo/sdk/logger" "github.com/google/uuid" ) @@ -26,7 +26,7 @@ type SessionInfo struct { sessionTime time.Time } -func SessionManagerInit(conf *sconfig.HttpServer) error { +func SessionManagerInit(conf *options.HttpServer) error { var sessionManage SessionManage sessionManage.Init(conf.SessionMaxAge, conf.SessionCount) return nil diff --git a/src/app/server/network/socketserver.go b/src/app/server/network/socketserver.go index 78af7f2c..9a7c21d1 100644 --- a/src/app/server/network/socketserver.go +++ b/src/app/server/network/socketserver.go @@ -19,7 +19,7 @@ import ( "strings" "gitee.com/openeuler/PilotGo/app/server/agentmanager" - sconfig "gitee.com/openeuler/PilotGo/app/server/config" + "gitee.com/openeuler/PilotGo/app/server/cmd/options" "gitee.com/openeuler/PilotGo/sdk/logger" "k8s.io/klog/v2" ) @@ -30,7 +30,7 @@ type SocketServer struct { OnStop func() } -func SocketServerInit(conf *sconfig.SocketServer, stopCh <-chan struct{}) error { +func SocketServerInit(conf *options.SocketServer, stopCh <-chan struct{}) error { server := &SocketServer{ // MessageProcesser: protocol.NewMessageProcesser(), OnAccept: agentmanager.AddandRunAgent, diff --git a/src/app/server/network/websocket/client_manager.go b/src/app/server/network/websocket/client_manager.go index 9d4bb5f7..4af298ef 100644 --- a/src/app/server/network/websocket/client_manager.go +++ b/src/app/server/network/websocket/client_manager.go @@ -4,6 +4,7 @@ import ( "sync" "gitee.com/openeuler/PilotGo/sdk/logger" + "k8s.io/klog/v2" ) var ( @@ -82,9 +83,12 @@ func (manager *ClientManager) EventUnregister(client *Client) { } // 管道处理程序 -func (manager *ClientManager) Start() { +func (manager *ClientManager) Start(stopCh <-chan struct{}) { for { select { + case <-stopCh: + klog.Warningln("websocket CliManager success exit") + return case conn := <-manager.Register: // 建立连接事件 manager.EventRegister(conn) diff --git a/src/app/server/service/auth/casbin.go b/src/app/server/service/auth/casbin.go index ec9d954f..a9ec7c0f 100644 --- a/src/app/server/service/auth/casbin.go +++ b/src/app/server/service/auth/casbin.go @@ -19,7 +19,7 @@ import ( "fmt" "sync" - sconfig "gitee.com/openeuler/PilotGo/app/server/config" + "gitee.com/openeuler/PilotGo/app/server/cmd/options" suser "gitee.com/openeuler/PilotGo/app/server/service/user" "gitee.com/openeuler/PilotGo/sdk/common" "gitee.com/openeuler/PilotGo/sdk/logger" @@ -45,7 +45,7 @@ const ( DomainPilotGo = "PilotGo-server" ) -func Casbin(conf *sconfig.MysqlDBInfo) { +func Casbin(conf *options.MysqlDBInfo) { text := ` [request_definition] r = sub, obj, act, domain diff --git a/src/app/server/service/auth/casbin_test.go b/src/app/server/service/auth/casbin_test.go index 415c6c7d..8d445008 100644 --- a/src/app/server/service/auth/casbin_test.go +++ b/src/app/server/service/auth/casbin_test.go @@ -16,12 +16,9 @@ package auth import ( "fmt" - "os" "testing" "github.com/stretchr/testify/assert" - - sconfig "gitee.com/openeuler/PilotGo/app/server/config" ) func TestGetRoles(t *testing.T) { @@ -44,15 +41,15 @@ func TestGetAllPolicy(t *testing.T) { fmt.Printf("policies: %v\n", policies) } -func TestMain(m *testing.M) { - err := sconfig.Init("D:\\tmp\\PilotGo-projects\\PilotGo\\config_server.yaml") - if err != nil { - fmt.Println("failed to load configure, exit..", err) - os.Exit(-1) - } +// func TestMain(m *testing.M) { +// err := sconfig.Init("D:\\tmp\\PilotGo-projects\\PilotGo\\config_server.yaml") +// if err != nil { +// fmt.Println("failed to load configure, exit..", err) +// os.Exit(-1) +// } - // 鉴权模块初始化 - Casbin(&sconfig.Config().MysqlDBinfo) +// // 鉴权模块初始化 +// Casbin(&sconfig.Config().MysqlDBinfo) - m.Run() -} +// m.Run() +// } diff --git a/src/app/server/service/plugin/plugin.go b/src/app/server/service/plugin/plugin.go index 5c1de0a1..f17dbfbc 100644 --- a/src/app/server/service/plugin/plugin.go +++ b/src/app/server/service/plugin/plugin.go @@ -350,7 +350,7 @@ func Handshake(url string, p *Plugin) error { if index > 0 { url = url[:index] } - port := strings.Split(config.Config().HttpServer.Addr, ":")[1] + port := strings.Split(config.OptionsConfig.HttpServer.Addr, ":")[1] url = strings.TrimRight(url, "/") + "/plugin_manage/bind?port=" + port logger.Debug("plugin url is:%s", url) diff --git a/src/dbmanager/db.go b/src/dbmanager/db.go index 69f9b63f..a943ce3b 100644 --- a/src/dbmanager/db.go +++ b/src/dbmanager/db.go @@ -1,7 +1,7 @@ package dbmanager import ( - sconfig "gitee.com/openeuler/PilotGo/app/server/config" + "gitee.com/openeuler/PilotGo/app/server/cmd/options" "gitee.com/openeuler/PilotGo/app/server/service/auditlog" "gitee.com/openeuler/PilotGo/app/server/service/batch" "gitee.com/openeuler/PilotGo/app/server/service/configfile" @@ -16,21 +16,22 @@ import ( "gitee.com/openeuler/PilotGo/dbmanager/redismanager" ) -func RedisdbInit(conf *sconfig.RedisDBInfo, stopCh <-chan struct{}) error { +func RedisdbInit(conf *options.RedisDBInfo, stopCh <-chan struct{}) error { err := redismanager.RedisInit( conf.RedisConn, conf.RedisPwd, conf.DefaultDB, conf.DialTimeout, conf.EnableRedis, - stopCh) + stopCh, + conf.UseTLS) if err != nil { return err } return nil } -func MysqldbInit(conf *sconfig.MysqlDBInfo) error { +func MysqldbInit(conf *options.MysqlDBInfo) error { _, err := mysqlmanager.MysqlInit( conf.HostName, conf.UserName, diff --git a/src/dbmanager/redismanager/redismanager.go b/src/dbmanager/redismanager/redismanager.go index cff88ace..8f5bdb4b 100644 --- a/src/dbmanager/redismanager/redismanager.go +++ b/src/dbmanager/redismanager/redismanager.go @@ -19,7 +19,6 @@ import ( "crypto/tls" "time" - "gitee.com/openeuler/PilotGo/app/server/config" "github.com/go-redis/redis/v8" "k8s.io/klog/v2" ) @@ -31,9 +30,9 @@ var ( global_redis *redis.Client ) -func RedisInit(redisConn, redisPwd string, defaultDB int, dialTimeout time.Duration, enableRedis bool, stopCh <-chan struct{}) error { +func RedisInit(redisConn, redisPwd string, defaultDB int, dialTimeout time.Duration, enableRedis bool, stopCh <-chan struct{}, useTLS bool) error { var cfg *redis.Options - if config.Config().RedisDBinfo.UseTLS { + if useTLS { cfg = &redis.Options{ Addr: redisConn, Password: redisPwd, -- Gitee