2 Star 3 Fork 0

YashanDB Community/yasrpc

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server.go 10.75 KB
一键复制 编辑 原始数据 按行查看 历史
huangsiyuan 提交于 2024-08-08 17:17 +08:00 . feat: support carry map in context
package yasrpc
import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
"git.yasdb.com/go/yasrpc/log"
)
// Handshake constants used when validating a client's Option and the server config.
const (
	// VerifiedNumber is the magic number ServeConn requires in Option.VerifiedNumber.
	VerifiedNumber = 0x23A6E1
	// SecretKeyMaxLen is the largest secret key length (in bytes, as measured by
	// len) that checkServer accepts.
	SecretKeyMaxLen = 1 << 8
)
const (
	// ConnectedSuccess is the status text written back after a CONNECT request
	// has been hijacked.
	ConnectedSuccess = "200 Connected to Yas RPC"

	// DefaultRPCPath and DefaultDebugPath are the defaults used by HandleHTTP.
	DefaultRPCPath   = "/_yasRPC_"
	DefaultDebugPath = "/debug/yasrpc"

	// NetAlreadyClosedErr is the error text Accept treats as "listener closed".
	NetAlreadyClosedErr = "use of closed network connection"
)
// ServerInfo describes a decoded request. serveCodec builds one per request and
// passes it to the server's preInterceptor before the request is dispatched.
type ServerInfo struct {
// Method is the "Service.Method" name taken from the request header.
Method string
// Headers are the headers carried in the request header.
Headers map[string]interface{}
}
// Server represents an RPC Server.
//
// ln is the listener installed by serveListener. serveListener and Close write
// it under lock, and Address reads it under lock. isNetClosed is set by Close
// so the accept loop can tell a deliberate shutdown from an accept failure.
// secretKey, when non-empty, must equal the SecretKey a client sends in its
// Option. serviceMap holds the services published by Register, keyed by name.
// preInterceptor, when set, is called for every successfully decoded request
// before dispatch. A non-nil error is sent back to the client and the call is
// skipped.
type Server struct {
ln net.Listener
isNetClosed bool
secretKey string
lock sync.RWMutex
// NOTE(review): presumably consumed by genListener (not visible here) — confirm.
tlsConfig *tls.Config // tls tcp connection config
serviceMap sync.Map // map[string]*service
// Logger receives all server log output; NewServer defaults it to log.DefaultLogger.
Logger log.Logger
preInterceptor func(cc Codec, info *ServerInfo) error
}
// NewServer returns a new Server that logs through logger, or through
// log.DefaultLogger when logger is nil. Each option function is then applied
// to the server in the order given.
func NewServer(logger log.Logger, optionFns ...OptionFn) *Server {
	srv := new(Server)
	if logger != nil {
		srv.Logger = logger
	} else {
		srv.Logger = log.DefaultLogger
	}
	for i := range optionFns {
		optionFns[i](srv)
	}
	return srv
}
var (
	// DefaultServer is the default instance of *Server. It is created with a nil
	// logger (so NewServer falls back to log.DefaultLogger) and no options.
	DefaultServer = NewServer(nil)
)
// Address returns the address of the listener installed by Serve, or nil when
// no listener has been set yet.
func (s *Server) Address() net.Addr {
	s.lock.RLock()
	defer s.lock.RUnlock()
	if ln := s.ln; ln != nil {
		return ln.Addr()
	}
	return nil
}
// Register publishes the receiver's methods on s under the receiver's service
// name. It fails when rcvr cannot be turned into a service, or when a service
// with the same name is already registered.
func (s *Server) Register(rcvr interface{}) error {
	svc, err := newService(rcvr, s.Logger)
	if err != nil {
		return err
	}
	_, loaded := s.serviceMap.LoadOrStore(svc.name, svc)
	if loaded {
		return fmt.Errorf("rpc: service already defined: %s", svc.name)
	}
	return nil
}
// Register publishes the receiver's methods in the DefaultServer.
// See (*Server).Register for the error conditions.
func Register(rcvr interface{}) error {
	return DefaultServer.Register(rcvr)
}
// findService resolves a "Service.Method" string to the registered service and
// its method type. The string is split on its last '.', so the service part may
// itself contain dots. An error is returned when there is no '.', when the
// service is unknown, or when the method is unknown. In the last case svc is
// still returned alongside the error.
func (s *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
	dot := strings.LastIndex(serviceMethod, ".")
	if dot == -1 {
		return nil, nil, fmt.Errorf("rpc server: service/method request invalid: %s", serviceMethod)
	}
	serviceName := serviceMethod[:dot]
	methodName := serviceMethod[dot+1:]

	found, ok := s.serviceMap.Load(serviceName)
	if !ok {
		return nil, nil, fmt.Errorf("rpc server: can't find service: %s", serviceName)
	}
	svc = found.(*service)
	if mtype = svc.method[methodName]; mtype == nil {
		err = fmt.Errorf("rpc server: can't find method: %s", methodName)
	}
	return svc, mtype, err
}
// Serve validates the server options, creates a listener for network/address,
// and serves RPC requests on it. It blocks until the accept loop returns.
func (s *Server) Serve(network, address string) error {
	// Validate the configuration before binding the address. Checking afterwards
	// leaked the freshly created listener whenever validation failed.
	if err := s.checkServer(); err != nil {
		return err
	}
	ln, err := s.genListener(network, address)
	if err != nil {
		return err
	}
	return s.serveListener(ln)
}
// checkServer validates the server options before serving. A configured secret
// key may be at most SecretKeyMaxLen bytes long. An empty key is not checked.
func (s *Server) checkServer() error {
	if s.secretKey == "" {
		return nil
	}
	if n := len(s.secretKey); n > SecretKeyMaxLen {
		return fmt.Errorf("secret key length %d is not in the range 0-%d", n, SecretKeyMaxLen)
	}
	return nil
}
// serveListener records ln as the server's listener, so Address and Close can
// reach it, and then runs the accept loop on it.
func (s *Server) serveListener(ln net.Listener) error {
	func() {
		s.lock.Lock()
		defer s.lock.Unlock()
		s.ln = ln
	}()
	return s.Accept(ln)
}
// Accept accepts connections on the listener and serves requests for each
// incoming connection, each in its own goroutine via ServeConn.
//
// Temporary accept errors are retried with exponential backoff, starting at
// 5ms and capped at 1s. Any other error ends the loop and is returned. It is
// logged at error level unless it looks like a closed listener: either its
// message contains NetAlreadyClosedErr, or Close was called on this server.
func (s *Server) Accept(ln net.Listener) error {
	var tempDelay time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if ne, ok := err.(interface{ Temporary() bool }); ok && ne.Temporary() {
				if tempDelay == 0 {
					tempDelay = 5 * time.Millisecond
				} else {
					tempDelay *= 2
				}
				if maxDelay := 1 * time.Second; tempDelay > maxDelay {
					tempDelay = maxDelay
				}
				s.Logger.Errorf("rpc server: accept error: %v; retrying in %v", err, tempDelay)
				time.Sleep(tempDelay)
				continue
			}
			// Close writes isNetClosed while holding s.lock. Reading it unguarded
			// was a data race with a concurrent Close.
			s.lock.RLock()
			closed := s.isNetClosed
			s.lock.RUnlock()

			s.Logger.Debug(err.Error())
			s.Logger.Debug(closed)
			if !closed && !strings.Contains(err.Error(), NetAlreadyClosedErr) {
				s.Logger.Errorf("rpc server: Accept = %v", err)
			}
			return err
		}
		tempDelay = 0
		go s.ServeConn(conn)
	}
}
// Accept runs DefaultServer's accept loop on ln, serving every incoming
// connection until the listener fails. See (*Server).Accept.
func Accept(ln net.Listener) error {
	return DefaultServer.Accept(ln)
}
// ServeConn runs the server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
//
// Handshake: the client first sends its Option as JSON. The connection is
// rejected (and closed) when:
//   - the verification number differs from VerifiedNumber,
//   - a server secret key is configured and the client's key differs,
//   - the codec type is unknown, or
//   - a file transfer is requested on a connection that is not a net.Conn.
//
// Otherwise the Option is echoed back as JSON. The connection then goes to
// serveTransfer (file transfer) or serveCodec (RPC). conn is always closed
// before ServeConn returns.
func (s *Server) ServeConn(conn io.ReadWriteCloser) {
	defer func() { _ = conn.Close() }()

	var opt Option
	if err := json.NewDecoder(conn).Decode(&opt); err != nil {
		s.Logger.Errorf("rpc server: options err: %v", err)
		return
	}
	if opt.VerifiedNumber != VerifiedNumber {
		s.Logger.Errorf("rpc server: invalid magic number %x\n", opt.VerifiedNumber)
		return
	}
	// Never log the keys themselves: that would leak the shared secret into logs.
	// NOTE(review): != is not constant-time; consider crypto/subtle.ConstantTimeCompare.
	if s.secretKey != "" && s.secretKey != opt.SecretKey {
		s.Logger.Errorf("rpc server: client secret key does not match the server secret key")
		return
	}
	fn := NewCodecFuncMap[opt.CodecType]
	if fn == nil {
		s.Logger.Errorf("rpc server: invalid codec type %s", opt.CodecType)
		return
	}

	// File transfer needs the raw net.Conn. Check for it before acknowledging the
	// handshake, instead of panicking later on an unchecked type assertion.
	var transferConn net.Conn
	if opt.FileTransfer != nil {
		nc, ok := conn.(net.Conn)
		if !ok {
			s.Logger.Errorf("rpc server: file transfer requires a net.Conn, got %T", conn)
			return
		}
		transferConn = nc
	}

	if err := json.NewEncoder(conn).Encode(&opt); err != nil {
		s.Logger.Errorf("rpc server: options err: %v", err)
		return
	}
	if opt.FileTransfer != nil {
		s.serveTransfer(transferConn, &opt)
		return
	}
	s.serveCodec(fn(conn), &opt)
}
// invalidRequest is the empty body sent with a response whose header carries an
// error (unreadable request, timeout), when there is no meaningful reply value.
var invalidRequest struct{}
// serveCodec is like ServeConn but uses the specified codec to decode requests
// and encode responses.
//
// It loops reading requests from cc:
//   - If the header cannot be read, the stream is unusable and the loop stops.
//   - If the header is readable but the request is invalid, an error response
//     is sent and the loop continues.
//   - If the optional preInterceptor rejects a request, its error is sent back
//     and the call is skipped.
//   - Every other request is handled in its own goroutine.
//
// Once the loop stops, serveCodec waits for all in-flight handlers and then
// closes cc. The sending mutex serializes writes so concurrent handlers never
// interleave responses.
func (s *Server) serveCodec(cc Codec, opt *Option) {
	sending := new(sync.Mutex)
	wg := new(sync.WaitGroup)
	for {
		req, err := s.readRequest(cc)
		if err != nil {
			if req == nil {
				// The header could not be read; it's not possible to recover, so
				// close the connection.
				break
			}
			req.h.Error = err.Error()
			s.sendResponse(cc, req.h, invalidRequest, sending)
			continue
		}
		if s.preInterceptor != nil {
			info := &ServerInfo{
				Method:  req.h.ServiceMethod,
				Headers: req.h.Headers,
			}
			// req is always non-nil here (readRequest succeeded), so a rejection
			// only needs to be reported, never treated as fatal.
			if err := s.preInterceptor(cc, info); err != nil {
				req.h.Error = err.Error()
				s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
				continue
			}
		}
		wg.Add(1)
		go s.handleRequest(cc, req, sending, wg, opt.HandleTimeout)
	}
	wg.Wait()
	_ = cc.Close()
}
// serveTransfer hands a file-transfer connection to UploadFile or DownloadFile,
// depending on opt.FileTransfer.TransferType. Any other transfer type is logged
// and the connection is closed immediately.
func (s *Server) serveTransfer(conn net.Conn, opt *Option) {
	transfer := opt.FileTransfer
	switch kind := transfer.TransferType; kind {
	case UploadType:
		s.UploadFile(conn, transfer)
	case DownloadType:
		s.DownloadFile(conn, transfer)
	default:
		s.Logger.Errorf("invalid transfer type: %d", kind)
		_ = conn.Close()
	}
}
// Close closes the listener installed by Serve, if any, and marks the server as
// closed, so the accept loop does not log the resulting error at error level.
// It returns the listener's Close error, or nil when no listener was set.
func (s *Server) Close() error {
	s.lock.Lock()
	defer s.lock.Unlock()
	if s.ln == nil {
		return nil
	}
	err := s.ln.Close()
	s.isNetClosed = true
	return err
}
// request stores all information of a call: the decoded header, the argument
// and reply values allocated for the target method, and the resolved method and
// service. When findService fails, mtype is nil (svc may be nil too) and
// argv/replyv are left unset; only h is meaningful then.
type request struct {
h *Header // header of request
argv, replyv reflect.Value // argv and replyv of request
mtype *methodType
svc *service
}
// readRequestHeader decodes the next request header from cc. Every error is
// returned. io.EOF and io.ErrUnexpectedEOF (a client hanging up) are not
// logged; any other error is also logged.
func (s *Server) readRequestHeader(cc Codec) (*Header, error) {
	h := new(Header)
	err := cc.ReadHeader(h)
	if err == nil {
		return h, nil
	}
	if err != io.EOF && err != io.ErrUnexpectedEOF {
		s.Logger.Errorf("rpc server: read header err: %v", err)
	}
	return nil, err
}
// readRequest reads one request from cc: the header, then the body, decoded
// into a freshly allocated argument value for the target method.
//
// It returns (nil, err) when the header cannot be read. It returns (req, err)
// with req.h set in two cases:
//   - the method cannot be resolved; the body is read and discarded so the
//     stream stays in step;
//   - the body fails to decode.
func (s *Server) readRequest(cc Codec) (*request, error) {
	h, err := s.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}
	req := &request{h: h}
	if req.svc, req.mtype, err = s.findService(h.ServiceMethod); err != nil {
		// Drain the body; the lookup error is what gets reported.
		_ = cc.ReadBody(nil)
		return req, err
	}
	req.argv, req.replyv = req.mtype.newArgv(), req.mtype.newReplyv()

	// ReadBody needs a pointer, so take the address of non-pointer argument values.
	target := req.argv.Interface()
	if req.argv.Type().Kind() != reflect.Pointer {
		target = req.argv.Addr().Interface()
	}
	if err = cc.ReadBody(target); err != nil {
		s.Logger.Errorf("rpc server: read body err: %v", err)
		return req, err
	}
	return req, nil
}
// sendResponse writes one response (header h plus body) to cc. It holds sending
// for the whole write, so responses from concurrent handlers are never
// interleaved. A write error is logged, not returned.
func (s *Server) sendResponse(cc Codec, h *Header, body interface{}, sending *sync.Mutex) {
	sending.Lock()
	defer sending.Unlock()
	err := cc.Write(h, body)
	if err == nil {
		return
	}
	s.Logger.Errorf("rpc server: write response err: %v", err)
}
// handleRequest invokes the method described by req and writes exactly one
// response for it through sendResponse. It calls wg.Done when it returns.
//
// With timeout == 0 it waits for the call and its response without limit.
// Otherwise, if the call has not completed within timeout, a timeout error is
// sent instead and handleRequest returns. The call itself keeps running in its
// goroutine (it cannot be cancelled), and its result is dropped when it ends.
func (s *Server) handleRequest(cc Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
	defer wg.Done()
	called := make(chan struct{})
	sent := make(chan struct{})
	finish := make(chan struct{})
	defer close(finish)

	go func() {
		err := req.svc.call(req.mtype, req.argv, req.replyv)
		select {
		case <-finish:
			// handleRequest already replied with a timeout error; drop the result.
			return
		case called <- struct{}{}:
		}
		if err != nil {
			req.h.Error = err.Error()
		}
		s.sendResponse(cc, req.h, req.replyv.Interface(), sending)
		close(sent)
	}()

	if timeout == 0 {
		<-called
		<-sent
		return
	}

	// Use an explicit timer so it is released as soon as the request finishes.
	// time.After would keep a pending timer per request until it fires.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case <-timer.C:
		req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout.String())
		s.sendResponse(cc, req.h, invalidRequest, sending)
	case <-called:
		<-sent
	}
}
// ServeHTTP implements an http.Handler that answers RPC requests.
//
// Only CONNECT is accepted; other methods get 405. The connection is hijacked,
// the ConnectedSuccess status line is written, and the raw connection is passed
// to ServeConn. A ResponseWriter that cannot be hijacked (e.g. HTTP/2) gets a
// 500 instead of panicking.
func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodConnect {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}
	hj, ok := w.(http.Hijacker)
	if !ok {
		s.Logger.Errorf("rpc hijacking %s: response writer does not implement http.Hijacker", req.RemoteAddr)
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = io.WriteString(w, "500 connection hijacking not supported\n")
		return
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		s.Logger.Errorf("rpc hijacking %s: %v", req.RemoteAddr, err)
		return
	}
	_, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", ConnectedSuccess)
	s.ServeConn(conn)
}
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath,
// and a debugging handler on debugPath, both on http.DefaultServeMux.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (s *Server) HandleHTTP(rpcPath, debugPath string) {
	http.Handle(rpcPath, s)
	http.Handle(debugPath, debugHTTP{s})
	// Log the path actually registered, not the package default.
	s.Logger.Infof("rpc server debug path: %s", debugPath)
}
// HandleHTTP registers DefaultServer's RPC handler on DefaultRPCPath and its
// debugging handler on DefaultDebugPath, via (*Server).HandleHTTP.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func HandleHTTP() {
	srv := DefaultServer
	srv.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
// HandleWeb is the Context-based counterpart of ServeHTTP.
//
// It accepts only CONNECT requests (others get 405) and hijacks the underlying
// connection. It then writes the ConnectedSuccess status line and passes the
// raw connection to ServeConn. A writer that cannot be hijacked gets a 500
// instead of panicking.
func (s *Server) HandleWeb(ctx *Context) {
	if ctx.Request.Method != http.MethodConnect {
		ctx.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		ctx.Writer.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(ctx.Writer, "405 must CONNECT\n")
		return
	}
	hj, ok := ctx.Writer.(http.Hijacker)
	if !ok {
		s.Logger.Errorf("rpc hijacking %s: response writer does not implement http.Hijacker", ctx.Request.RemoteAddr)
		ctx.Writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
		ctx.Writer.WriteHeader(http.StatusInternalServerError)
		_, _ = io.WriteString(ctx.Writer, "500 connection hijacking not supported\n")
		return
	}
	conn, _, err := hj.Hijack()
	if err != nil {
		s.Logger.Errorf("rpc hijacking %s: %v", ctx.Request.RemoteAddr, err)
		return
	}
	_, _ = fmt.Fprintf(conn, "HTTP/1.0 %s\n\n", ConnectedSuccess)
	s.ServeConn(conn)
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yashan_tech/yasrpc.git
git@gitee.com:yashan_tech/yasrpc.git
yashan_tech
yasrpc
yasrpc
master

搜索帮助