1 Star 0 Fork 1

vanve/server

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
clients.go 18.41 KB
一键复制 编辑 原始数据 按行查看 历史
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/rs/xid"
"github.com/mochi-mqtt/server/v2/packets"
)
// Defaults and limits applied to connecting clients.
const (
defaultKeepalive uint16 = 10 // the default connection keepalive value in seconds.
defaultClientProtocolVersion byte = 4 // the default mqtt protocol version of connecting clients (if somehow unspecified).
minimumKeepalive uint16 = 5 // the minimum recommended keepalive in seconds; ParseConnect logs a warning for values at or below it.
)
var (
// ErrMinimumKeepalive carries the warning text logged by ParseConnect when a
// connecting client requests a keepalive at or below minimumKeepalive.
ErrMinimumKeepalive = errors.New("client keepalive is below minimum recommended value and may exhibit connection instability")
)
// ReadFn is the function signature for the function used for reading and processing new packets.
// Client.Read invokes it once per decoded packet and stops reading as soon as it returns a non-nil error.
type ReadFn func(*Client, packets.Packet) error
// Clients contains a map of the clients known by the broker.
// The embedded RWMutex guards the map; the methods on Clients acquire it themselves.
type Clients struct {
internal map[string]*Client // clients known by the broker, keyed on client id.
sync.RWMutex // guards internal.
}
// NewClients creates an empty Clients registry whose internal map is
// initialised and ready to accept clients.
func NewClients() *Clients {
	registry := new(Clients)
	registry.internal = map[string]*Client{}
	return registry
}
// Add stores val in the registry under its client id, replacing any client
// previously stored under the same id. The write lock is held for the update.
func (cl *Clients) Add(val *Client) {
	cl.Lock()
	defer cl.Unlock()

	id := val.ID
	cl.internal[id] = val
}
// GetAll returns a snapshot of every client in the registry. The returned map
// is a fresh copy, so callers may iterate or modify it without holding the
// lock; the *Client values themselves are shared.
func (cl *Clients) GetAll() map[string]*Client {
	cl.RLock()
	defer cl.RUnlock()

	snapshot := make(map[string]*Client)
	for id := range cl.internal {
		snapshot[id] = cl.internal[id]
	}
	return snapshot
}
// Get looks up a client by id. The boolean reports whether a client was
// stored under that id.
func (cl *Clients) Get(id string) (*Client, bool) {
	cl.RLock()
	defer cl.RUnlock()

	client, found := cl.internal[id]
	return client, found
}
// Len reports how many clients are currently held in the registry.
func (cl *Clients) Len() int {
	cl.RLock()
	defer cl.RUnlock()

	return len(cl.internal)
}
// Delete drops the client stored under id from the registry. Deleting an id
// that is not present is a no-op.
func (cl *Clients) Delete(id string) {
	cl.Lock()
	defer cl.Unlock()

	delete(cl.internal, id)
}
// GetByListener returns the clients that are attached to the listener with
// the given id and whose connection is not closed. The result is never nil;
// it is empty when no client matches.
func (cl *Clients) GetByListener(id string) []*Client {
	cl.RLock()
	defer cl.RUnlock()

	// Size the slice from the map directly. Calling cl.Len() here would take
	// the read lock a second time while it is already held, and sync.RWMutex
	// prohibits recursive read locking: a writer queued between the two
	// acquisitions blocks the second RLock and the goroutine deadlocks.
	clients := make([]*Client, 0, len(cl.internal))
	for _, client := range cl.internal {
		if client.Net.Listener == id && !client.Closed() {
			clients = append(clients, client)
		}
	}

	return clients
}
// Client contains information about a client known by the broker.
// Instances are created by newClient and populated from the CONNECT packet by ParseConnect.
type Client struct {
Properties ClientProperties // client properties
State ClientState // the operational state of the client.
Net ClientConnection // network connection state of the client
ID string // the client id.
ops *ops // ops provides a reference to server ops.
sync.RWMutex // mutex; held by NextPacketID and by WritePacket while writing to the connection.
}
// ClientConnection contains the connection transport and metadata for the client.
type ClientConnection struct {
Conn net.Conn // the net.Conn used to establish the connection
bconn *bufio.Reader // a buffered net.Conn for reading packets
outbuf *bytes.Buffer // a buffer for writing packets; created on demand by WritePacket and reset to nil after a successful flush
Remote string // the remote address of the client
Listener string // listener id of the client
Inline bool // if true, the client is the built-in 'inline' embedded client
}
// ClientProperties contains the properties which define the client behaviour.
// ParseConnect fills these from the client's CONNECT packet.
type ClientProperties struct {
Props packets.Properties // a copy of the properties sent in the CONNECT packet
Will Will // the last will and testament, set when the CONNECT packet has the will flag
Username []byte // the username supplied in the CONNECT packet
ProtocolVersion byte // the mqtt protocol version used by the client
Clean bool // the clean flag from the CONNECT packet
}
// Will contains the last will and testament details for a client connection.
type Will struct {
Payload []byte // the will message payload
User []packets.UserProperty // user properties taken from the CONNECT will properties
TopicName string // the topic the will message is addressed to
Flag uint32 // 0,1 - set to 1 by ParseConnect when the CONNECT packet carries a will
WillDelayInterval uint32 // the will delay interval; ParseConnect caps it at the session expiry interval when that is set and smaller
Qos byte // the qos of the will message
Retain bool // the retain flag of the will message
}
// ClientState tracks the state of the client.
type ClientState struct {
TopicAliases TopicAliases // a map of topic aliases
stopCause atomic.Value // reason for stopping; holds the error passed to Stop, if any
Inflight *Inflight // a map of in-flight qos messages
Subscriptions *Subscriptions // a map of the subscription filters a client maintains
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets, drained by WriteLoop
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context, called by Stop
outboundQty int32 // number of messages currently in the outbound queue
Keepalive uint16 // the number of seconds the connection can wait
ServerKeepalive bool // keepalive was set by the server
}
// newClient builds a Client bound to the server ops o. The client starts with
// the default keepalive and protocol version, empty inflight, subscription and
// topic-alias state, an outbound queue sized from the server capabilities, and
// a cancellable "open" context. When c is non-nil the connection, a buffered
// reader over it and its remote address are recorded as well; a nil c yields a
// client with no network transport.
func newClient(c net.Conn, o *ops) *Client {
	openCtx, cancelOpen := context.WithCancel(context.Background())

	cl := &Client{ops: o}
	cl.Properties.ProtocolVersion = defaultClientProtocolVersion // default protocol version

	cl.State.Inflight = NewInflights()
	cl.State.Subscriptions = NewSubscriptions()
	cl.State.TopicAliases = NewTopicAliases(o.options.Capabilities.TopicAliasMaximum)
	cl.State.open = openCtx
	cl.State.cancelOpen = cancelOpen
	cl.State.Keepalive = defaultKeepalive
	cl.State.outbound = make(chan *packets.Packet, o.options.Capabilities.MaximumClientWritesPending)

	if c == nil {
		return cl
	}

	cl.Net.Conn = c
	cl.Net.bconn = bufio.NewReaderSize(c, o.options.ClientNetReadBufferSize)
	cl.Net.Remote = c.RemoteAddr().String()

	return cl
}
// WriteLoop drains the client's outbound queue, writing each queued packet to
// the connection and decrementing the outbound counter afterwards. Write
// failures are only logged. The loop returns once the client's open context
// is cancelled.
func (cl *Client) WriteLoop() {
	closed := cl.State.open.Done()

	for {
		select {
		case <-closed:
			return
		case pk := <-cl.State.outbound:
			err := cl.WritePacket(*pk)
			if err != nil {
				// TODO : Figure out what to do with error
				cl.ops.log.Debug("failed publishing packet", "error", err, "client", cl.ID, "packet", pk)
			}
			atomic.AddInt32(&cl.State.outboundQty, -1)
		}
	}
}
// ParseConnect applies the parameters and properties of a CONNECT packet to
// the client.
//
// lid is the id of the listener the client connected through. pk is the
// decoded CONNECT packet. The method records the protocol version, username,
// clean flag and a copy of the packet properties, resolves the client id
// (generating one when the packet supplies none), caps the client's receive
// maximum at the server's maximum inflight, resets the inflight quotas and
// outbound topic aliases, and stores the will when the will flag is set.
func (cl *Client) ParseConnect(lid string, pk packets.Packet) {
	cl.Net.Listener = lid

	cl.Properties.ProtocolVersion = pk.ProtocolVersion
	cl.Properties.Username = pk.Connect.Username
	cl.Properties.Clean = pk.Connect.Clean
	cl.Properties.Props = pk.Properties.Copy(false)

	// Resolve the client id before anything is logged, so that the keepalive
	// warning below reports the id from this CONNECT packet rather than
	// whatever (usually empty) id the client held beforehand.
	cl.ID = pk.Connect.ClientIdentifier
	if cl.ID == "" {
		cl.ID = xid.New().String() // [MQTT-3.1.3-6] [MQTT-3.1.3-7]
		cl.Properties.Props.AssignedClientID = cl.ID
	}

	if cl.Properties.Props.ReceiveMaximum > cl.ops.options.Capabilities.MaximumInflight { // 3.3.4 Non-normative
		cl.Properties.Props.ReceiveMaximum = cl.ops.options.Capabilities.MaximumInflight
	}

	if pk.Connect.Keepalive <= minimumKeepalive {
		cl.ops.log.Warn(
			ErrMinimumKeepalive.Error(),
			"client", cl.ID,
			"keepalive", pk.Connect.Keepalive,
			"recommended", minimumKeepalive,
		)
	}

	cl.State.Keepalive = pk.Connect.Keepalive                                              // [MQTT-3.2.2-22]
	cl.State.Inflight.ResetReceiveQuota(int32(cl.ops.options.Capabilities.ReceiveMaximum)) // server receive max per client
	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))            // client receive max
	cl.State.TopicAliases.Outbound = NewOutboundTopicAliases(cl.Properties.Props.TopicAliasMaximum)

	if !pk.Connect.WillFlag {
		return
	}

	cl.Properties.Will = Will{
		Qos:               pk.Connect.WillQos,
		Retain:            pk.Connect.WillRetain,
		Payload:           pk.Connect.WillPayload,
		TopicName:         pk.Connect.WillTopic,
		WillDelayInterval: pk.Connect.WillProperties.WillDelayInterval,
		User:              pk.Connect.WillProperties.User,
		Flag:              1, // marks that a will is present
	}

	// The will delay cannot outlast the session: cap it at the session expiry
	// interval when the client supplied one that is shorter.
	if pk.Properties.SessionExpiryIntervalFlag &&
		pk.Properties.SessionExpiryInterval < pk.Connect.WillProperties.WillDelayInterval {
		cl.Properties.Will.WillDelayInterval = pk.Properties.SessionExpiryInterval
	}
}
// refreshDeadline refreshes the read/write deadline for the net.Conn
// connection. The deadline is set to one and a half times the keepalive
// (in seconds) from now; a keepalive of 0 clears the deadline. It does
// nothing when the client has no connection.
func (cl *Client) refreshDeadline(keepalive uint16) {
	var expiry time.Time // zero time disables the deadline when keepalive = 0
	if keepalive > 0 {
		// Widen before adding: keepalive+(keepalive/2) evaluated in uint16
		// wraps around for keepalive values above 43690, which would yield a
		// deadline shorter than the keepalive itself.
		seconds := int64(keepalive) + int64(keepalive/2)
		expiry = time.Now().Add(time.Duration(seconds) * time.Second) // [MQTT-3.1.2-22]
	}

	if cl.Net.Conn != nil {
		_ = cl.Net.Conn.SetDeadline(expiry) // [MQTT-3.1.2-22] best effort; a failing conn surfaces on the next read/write
	}
}
// NextPacketID returns the next available (unused) packet id for the client.
// If no unused packet ids are available, an error is returned and the client
// should be disconnected.
//
// The search starts after the last id handed out, wraps to 1 once the
// configured maximum packet id is reached, and gives up with
// packets.ErrQuotaExceeded after one full lap. The client lock is held for
// the duration of the search.
func (cl *Client) NextPacketID() (i uint32, err error) {
cl.Lock()
defer cl.Unlock()
i = atomic.LoadUint32(&cl.State.packetID)
started := i
overflowed := false
for {
// Back at the starting point after having wrapped: every id was probed.
if overflowed && i == started {
return 0, packets.ErrQuotaExceeded
}
// At (or beyond) the maximum: wrap around; the i++ below makes 1 the next candidate.
if i >= cl.ops.options.Capabilities.maximumPacketID {
overflowed = true
i = 0
continue
}
i++
// An id is free when no inflight message is stored under it.
if _, ok := cl.State.Inflight.Get(uint16(i)); !ok {
atomic.StoreUint32(&cl.State.packetID, i)
return i, nil
}
}
}
// ResendInflightMessages writes every pending inflight message to the client
// again. PUBLISH packets are flagged as duplicates before sending. PUBACK and
// PUBCOMP packets are removed from the inflight store once written, since
// they complete their flow. The first write error aborts the resend and is
// returned. The force parameter is currently not consulted.
func (cl *Client) ResendInflightMessages(force bool) error {
	if cl.State.Inflight.Len() == 0 {
		return nil
	}

	pending := cl.State.Inflight.GetAll(false)
	for _, msg := range pending {
		kind := msg.FixedHeader.Type
		if kind == packets.Publish {
			msg.FixedHeader.Dup = true // [MQTT-3.3.1-1] [MQTT-3.3.1-3]
		}

		cl.ops.hooks.OnQosPublish(cl, msg, msg.Created, 0)
		if err := cl.WritePacket(msg); err != nil {
			return err
		}

		completesFlow := kind == packets.Puback || kind == packets.Pubcomp
		if !completesFlow {
			continue
		}
		if !cl.State.Inflight.Delete(msg.PacketID) {
			continue
		}
		cl.ops.hooks.OnQosComplete(cl, msg)
		atomic.AddInt64(&cl.ops.info.Inflight, -1)
	}

	return nil
}
// ClearInflights removes every inflight message held for the client, e.g.
// for a disconnected user with a clean session. Each message actually removed
// is reported to the OnQosDropped hook and subtracted from the server's
// inflight counter.
func (cl *Client) ClearInflights() {
	pending := cl.State.Inflight.GetAll(false)
	for _, msg := range pending {
		if !cl.State.Inflight.Delete(msg.PacketID) {
			continue
		}
		cl.ops.hooks.OnQosDropped(cl, msg)
		atomic.AddInt64(&cl.ops.info.Inflight, -1)
	}
}
// ClearExpiredInflights removes inflight messages that are no longer valid at
// the unix time now and returns the packet ids it removed (an empty, non-nil
// slice when nothing was removed).
//
// A message is removed when it is a protocol version 5 message whose expiry
// has passed, or when maximumExpiry is greater than 0 and the message has been
// held for longer than maximumExpiry seconds since it was created.
func (cl *Client) ClearExpiredInflights(now, maximumExpiry int64) []uint16 {
	removed := make([]uint16, 0)

	for _, msg := range cl.State.Inflight.GetAll(false) {
		pastExpiry := msg.ProtocolVersion == 5 && msg.Expiry > 0 && msg.Expiry < now // [MQTT-3.3.2-5]
		// A configured server-side maximum forces removal regardless of the
		// message's own expiry.
		overRetention := maximumExpiry > 0 && now-msg.Created > maximumExpiry
		if !pastExpiry && !overRetention {
			continue
		}

		if !cl.State.Inflight.Delete(msg.PacketID) {
			continue
		}
		cl.ops.hooks.OnQosDropped(cl, msg)
		atomic.AddInt64(&cl.ops.info.Inflight, -1)
		removed = append(removed, msg.PacketID)
	}

	return removed
}
// Read is the client's inbound loop. Until the client is closed it refreshes
// the connection deadline from the current keepalive, reads one fixed header
// and the packet that follows it, and passes the packet to packetHandler.
// It returns nil when the client is found closed at the top of an iteration,
// or the first error produced by reading or by packetHandler.
func (cl *Client) Read(packetHandler ReadFn) error {
	for !cl.Closed() {
		cl.refreshDeadline(cl.State.Keepalive)

		header := new(packets.FixedHeader)
		if err := cl.ReadFixedHeader(header); err != nil {
			return err
		}

		pk, err := cl.ReadPacket(header)
		if err != nil {
			return err
		}

		// Process inbound packet.
		if err := packetHandler(cl, pk); err != nil {
			return err
		}
	}

	return nil
}
// Stop shuts the client down. The shutdown runs at most once per client:
// it closes the network connection (if any), records err as the stop cause
// when err is non-nil, cancels the open context so processing goroutines
// exit, and stamps the disconnect time. Later calls are no-ops.
func (cl *Client) Stop(err error) {
	shutdown := func() {
		if conn := cl.Net.Conn; conn != nil {
			_ = conn.Close() // omit close error
		}

		if err != nil {
			cl.State.stopCause.Store(err)
		}

		if cancel := cl.State.cancelOpen; cancel != nil {
			cancel()
		}

		atomic.StoreInt64(&cl.State.disconnected, time.Now().Unix())
	}

	cl.State.endOnce.Do(shutdown)
}
// StopCause reports the error recorded when the client was stopped, or nil
// when no cause has been stored.
func (cl *Client) StopCause() error {
	cause := cl.State.stopCause.Load()
	if cause == nil {
		return nil
	}

	return cause.(error)
}
// Closed reports whether the client is no longer open for packet exchange:
// either it has no open context at all, or that context has been cancelled.
func (cl *Client) Closed() bool {
	open := cl.State.open
	if open == nil {
		return true
	}

	return open.Err() != nil
}
// ReadFixedHeader reads the fixed header of the next packet from the client's
// buffered connection into fh: the first byte is decoded into the header
// fields and the following variable-length integer into fh.Remaining.
//
// It returns ErrConnectionClosed when the client has no buffered reader, any
// read or decode error encountered, or packets.ErrPacketTooLarge when a
// maximum packet size is configured and the packet exceeds it. On success the
// header bytes are added to the server's received-bytes counter.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
	reader := cl.Net.bconn
	if reader == nil {
		return ErrConnectionClosed
	}

	first, err := reader.ReadByte()
	if err != nil {
		return err
	}

	if err = fh.Decode(first); err != nil {
		return err
	}

	var lengthBytes int
	fh.Remaining, lengthBytes, err = packets.DecodeLength(reader)
	if err != nil {
		return err
	}

	limit := cl.ops.options.Capabilities.MaximumPacketSize
	if limit > 0 && uint32(fh.Remaining+1) > limit {
		return packets.ErrPacketTooLarge // [MQTT-3.2.2-15]
	}

	// +1 accounts for the header byte read ahead of the length bytes.
	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(lengthBytes+1))

	return nil
}
// ReadPacket reads the remaining bytes announced by fh from the client's
// buffered connection and decodes them into an MQTT packet.
//
// The returned packet inherits the client's protocol version and a copy of
// fh. The server's packet, byte and (for successfully decoded PUBLISH
// packets) message counters are updated along the way. A successfully
// decoded packet is passed through the OnPacketRead hook, whose results are
// returned. On a read or decode error the partially filled packet is returned
// together with the error; an unknown packet type yields an
// "invalid packet type" error.
func (cl *Client) ReadPacket(fh *packets.FixedHeader) (pk packets.Packet, err error) {
	atomic.AddInt64(&cl.ops.info.PacketsReceived, 1)

	pk.ProtocolVersion = cl.Properties.ProtocolVersion // inherit client protocol version for decoding
	pk.FixedHeader = *fh

	// body is allocated fresh for every packet and io.ReadFull copies the
	// bytes out of the buffered reader into it, so the decoded packet can
	// safely keep references into body without a further defensive copy.
	body := make([]byte, pk.FixedHeader.Remaining)
	n, err := io.ReadFull(cl.Net.bconn, body)
	if err != nil {
		return pk, err
	}
	atomic.AddInt64(&cl.ops.info.BytesReceived, int64(n))

	switch pk.FixedHeader.Type {
	case packets.Connect:
		err = pk.ConnectDecode(body)
	case packets.Disconnect:
		err = pk.DisconnectDecode(body)
	case packets.Connack:
		err = pk.ConnackDecode(body)
	case packets.Publish:
		err = pk.PublishDecode(body)
		if err == nil {
			atomic.AddInt64(&cl.ops.info.MessagesReceived, 1)
		}
	case packets.Puback:
		err = pk.PubackDecode(body)
	case packets.Pubrec:
		err = pk.PubrecDecode(body)
	case packets.Pubrel:
		err = pk.PubrelDecode(body)
	case packets.Pubcomp:
		err = pk.PubcompDecode(body)
	case packets.Subscribe:
		err = pk.SubscribeDecode(body)
	case packets.Suback:
		err = pk.SubackDecode(body)
	case packets.Unsubscribe:
		err = pk.UnsubscribeDecode(body)
	case packets.Unsuback:
		err = pk.UnsubackDecode(body)
	case packets.Pingreq, packets.Pingresp:
		// Nothing is decoded for ping packets.
	case packets.Auth:
		err = pk.AuthDecode(body)
	default:
		err = fmt.Errorf("invalid packet type; %v", pk.FixedHeader.Type)
	}

	if err != nil {
		return pk, err
	}

	pk, err = cl.ops.hooks.OnPacketRead(cl, pk)
	return pk, err
}
// WritePacket encodes and writes a packet to the client.
//
// It returns ErrConnectionClosed when the client is closed, and nil without
// doing anything when the client has no network connection. Otherwise the
// packet is adjusted for this client (remaining expiry, protocol version,
// maximum size and problem/response info modifiers), passed through the
// OnPacketEncode hook, encoded according to its type, and either written
// straight to the connection or appended to the client's write buffer.
// An encoding error, packets.ErrPacketTooLarge (when the encoded packet
// exceeds the maximum size), or a write error is returned. After a successful
// write the server's sent counters are updated and OnPacketSent is called.
func (cl *Client) WritePacket(pk packets.Packet) error {
if cl.Closed() {
return ErrConnectionClosed
}
if cl.Net.Conn == nil {
return nil
}
// Send the remaining lifetime of the message rather than its original interval.
// NOTE(review): when pk.Expiry is already in the past the difference is
// negative and the uint32 conversion wraps to a very large interval - confirm
// that callers never pass an expired packet here.
if pk.Expiry > 0 {
pk.Properties.MessageExpiryInterval = uint32(pk.Expiry - time.Now().Unix()) // [MQTT-3.3.2-6]
}
pk.ProtocolVersion = cl.Properties.ProtocolVersion
if pk.Mods.MaxSize == 0 { // NB we use this statement to embed client packet sizes in tests
pk.Mods.MaxSize = cl.Properties.Props.MaximumPacketSize
}
if cl.Properties.Props.RequestProblemInfoFlag && cl.Properties.Props.RequestProblemInfo == 0x0 {
pk.Mods.DisallowProblemInfo = true // [MQTT-3.1.2-29] strict, no problem info on any packet if set
}
if pk.FixedHeader.Type != packets.Connack || cl.Properties.Props.RequestResponseInfo == 0x1 || cl.ops.options.Capabilities.Compatibilities.AlwaysReturnResponseInfo {
pk.Mods.AllowResponseInfo = true // [MQTT-3.1.2-28] we need to know which properties we can encode
}
pk = cl.ops.hooks.OnPacketEncode(cl, pk)
var err error
buf := new(bytes.Buffer)
// Encode into buf using the encoder that matches the packet type.
switch pk.FixedHeader.Type {
case packets.Connect:
err = pk.ConnectEncode(buf)
case packets.Connack:
err = pk.ConnackEncode(buf)
case packets.Publish:
err = pk.PublishEncode(buf)
case packets.Puback:
err = pk.PubackEncode(buf)
case packets.Pubrec:
err = pk.PubrecEncode(buf)
case packets.Pubrel:
err = pk.PubrelEncode(buf)
case packets.Pubcomp:
err = pk.PubcompEncode(buf)
case packets.Subscribe:
err = pk.SubscribeEncode(buf)
case packets.Suback:
err = pk.SubackEncode(buf)
case packets.Unsubscribe:
err = pk.UnsubscribeEncode(buf)
case packets.Unsuback:
err = pk.UnsubackEncode(buf)
case packets.Pingreq:
err = pk.PingreqEncode(buf)
case packets.Pingresp:
err = pk.PingrespEncode(buf)
case packets.Disconnect:
err = pk.DisconnectEncode(buf)
case packets.Auth:
err = pk.AuthEncode(buf)
default:
err = fmt.Errorf("%w: %v", packets.ErrNoValidPacketAvailable, pk.FixedHeader.Type)
}
if err != nil {
return err
}
if pk.Mods.MaxSize > 0 && uint32(buf.Len()) > pk.Mods.MaxSize {
return packets.ErrPacketTooLarge // [MQTT-3.1.2-24] [MQTT-3.1.2-25]
}
// The write itself runs under the client lock. The closure returns the number
// of bytes this packet contributed, which is what gets added to BytesSent.
n, err := func() (int64, error) {
cl.Lock()
defer cl.Unlock()
// Nothing else is queued: send now, flushing anything buffered earlier together with this packet.
if len(cl.State.outbound) == 0 {
if cl.Net.outbuf == nil {
return buf.WriteTo(cl.Net.Conn)
}
// first write to buffer, then flush buffer
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
err = cl.flushOutbuf()
return int64(n), err
}
// there are more writes in the queue
if cl.Net.outbuf == nil {
// A packet at least as large as the write buffer size is not worth buffering; write it directly.
if buf.Len() >= cl.ops.options.ClientNetWriteBufferSize {
return buf.WriteTo(cl.Net.Conn)
}
cl.Net.outbuf = new(bytes.Buffer)
}
n, _ := cl.Net.outbuf.Write(buf.Bytes()) // will always be successful
// Keep accumulating until the buffer reaches the configured write buffer size.
if cl.Net.outbuf.Len() < cl.ops.options.ClientNetWriteBufferSize {
return int64(n), nil
}
err = cl.flushOutbuf()
return int64(n), err
}()
if err != nil {
return err
}
atomic.AddInt64(&cl.ops.info.BytesSent, n)
atomic.AddInt64(&cl.ops.info.PacketsSent, 1)
if pk.FixedHeader.Type == packets.Publish {
atomic.AddInt64(&cl.ops.info.MessagesSent, 1)
}
cl.ops.hooks.OnPacketSent(cl, pk, buf.Bytes())
return err
}
// flushOutbuf writes the client's pending write buffer to the connection.
// It is a no-op when there is no buffer. On success the buffer is discarded
// (set to nil); on failure the write error is returned and the buffer is kept.
func (cl *Client) flushOutbuf() (err error) {
	pending := cl.Net.outbuf
	if pending == nil {
		return nil
	}

	if _, err = pending.WriteTo(cl.Net.Conn); err != nil {
		return err
	}

	cl.Net.outbuf = nil
	return nil
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/vanve/server.git
git@gitee.com:vanve/server.git
vanve
server
server
file-based-config

搜索帮助