diff --git a/sdk/etcd/gateway/middleware.go b/sdk/etcd/gateway/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..334963096b57aee3c7427140340df6fc361995db --- /dev/null +++ b/sdk/etcd/gateway/middleware.go @@ -0,0 +1,134 @@ +/* + * Copyright (c) KylinSoft Co., Ltd. 2024.All rights reserved. + * PilotGo licensed under the Mulan Permissive Software License, Version 2. + * See LICENSE file for more details. + * Author: zhanghan2021 + * Date: Tue Dec 10 14:46:05 2024 +0800 + */ +package gateway + +import ( + "net/http" + "sync" + "time" + + "gitee.com/openeuler/PilotGo/sdk/logger" + "golang.org/x/time/rate" +) + +// Middleware represents a chain of http handlers +type Middleware func(http.Handler) http.Handler + +// Chain chains multiple middleware together +func Chain(middlewares ...Middleware) Middleware { + return func(next http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + next = middlewares[i](next) + } + return next + } +} + +// RateLimiter implements rate limiting middleware +type RateLimiter struct { + visitors map[string]*rate.Limiter + mu sync.RWMutex + r rate.Limit + b int +} + +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + visitors: make(map[string]*rate.Limiter), + r: r, + b: b, + } +} + +func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter, exists := rl.visitors[ip] + if !exists { + limiter = rate.NewLimiter(rl.r, rl.b) + rl.visitors[ip] = limiter + } + + return limiter +} + +func (rl *RateLimiter) RateLimit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limiter := rl.getLimiter(r.RemoteAddr) + if !limiter.Allow() { + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// AuthMiddleware implements authentication middleware +type AuthMiddleware struct { + tokenValidator func(string) bool +} + +func NewAuthMiddleware(validator func(string) bool) *AuthMiddleware { + return &AuthMiddleware{ + tokenValidator: validator, + } +} + +func (am *AuthMiddleware) Authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token == "" { + http.Error(w, "Authorization token required", http.StatusUnauthorized) + return + } + + if !am.tokenValidator(token) { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// LoggingMiddleware implements request logging +func LoggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // Call the next handler + next.ServeHTTP(w, r) + + // Log the request + duration := time.Since(start) + logger.Info( + "%s %s %s %v", + r.Method, + r.RequestURI, + r.RemoteAddr, + duration, + ) + }) +} + +// CORSMiddleware implements CORS support +func CORSMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +}