1 Star 0 Fork 1

vanve/server

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
server_test.go 107.70 KB
一键复制 编辑 原始数据 按行查看 历史
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
7377837793780378137823783378437853786378737883789379037913792379337943795379637973798379938003801380238033804380538063807380838093810381138123813381438153816381738183819382038213822382338243825382638273828382938303831383238333834383538363837383838393840384138423843384438453846384738483849385038513852385338543855385638573858385938603861386238633864
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co
package mqtt
import (
"bytes"
"encoding/binary"
"io"
"log/slog"
"net"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/mochi-mqtt/server/v2/hooks/storage"
"github.com/mochi-mqtt/server/v2/listeners"
"github.com/mochi-mqtt/server/v2/packets"
"github.com/mochi-mqtt/server/v2/system"
"github.com/stretchr/testify/require"
)
// logger is a structured text logger whose output is discarded, so tests can
// exercise logging paths without producing noise.
var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
// ProtocolTest is a table of protocol test rows. Each row holds a protocol
// version, an input packet case (in), an output packet case (out), and
// free-form extra data.
type ProtocolTest []struct {
	protocolVersion byte
	in, out         packets.TPacketCase
	data            map[string]any
}
// AllowHook is a test hook that approves every connection and every ACL check.
type AllowHook struct {
	HookBase
}

// SetOpts stores the logger and hook options on the embedded HookBase fields.
func (h *AllowHook) SetOpts(l *slog.Logger, opts *HookOptions) {
	h.Log, h.Opts = l, opts
}

// ID returns the hook identifier.
func (h *AllowHook) ID() string { return "allow-all-auth" }

// Provides reports whether b is OnConnectAuthenticate or OnACLCheck.
func (h *AllowHook) Provides(b byte) bool {
	return bytes.IndexByte([]byte{OnConnectAuthenticate, OnACLCheck}, b) >= 0
}

// OnConnectAuthenticate accepts every client.
func (h *AllowHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
	return true
}

// OnACLCheck grants every read and write.
func (h *AllowHook) OnACLCheck(cl *Client, topic string, write bool) bool {
	return true
}
// DenyHook is a test hook that rejects every connection and every ACL check.
type DenyHook struct {
	HookBase
}

// SetOpts stores the logger and hook options on the embedded HookBase fields.
func (h *DenyHook) SetOpts(l *slog.Logger, opts *HookOptions) {
	h.Log, h.Opts = l, opts
}

// ID returns the hook identifier.
func (h *DenyHook) ID() string { return "deny-all-auth" }

// Provides reports whether b is OnConnectAuthenticate or OnACLCheck.
func (h *DenyHook) Provides(b byte) bool {
	return bytes.IndexByte([]byte{OnConnectAuthenticate, OnACLCheck}, b) >= 0
}

// OnConnectAuthenticate refuses every client.
func (h *DenyHook) OnConnectAuthenticate(cl *Client, pk packets.Packet) bool {
	return false
}

// OnACLCheck refuses every read and write.
func (h *DenyHook) OnACLCheck(cl *Client, topic string, write bool) bool {
	return false
}
// DelayHook is a test hook that blocks in OnDisconnect for DisconnectDelay,
// letting tests widen the window in which a disconnect is still in progress.
type DelayHook struct {
	HookBase
	DisconnectDelay time.Duration
}

// SetOpts stores the logger and hook options on the embedded HookBase fields.
func (h *DelayHook) SetOpts(l *slog.Logger, opts *HookOptions) {
	h.Log, h.Opts = l, opts
}

// ID returns the hook identifier.
func (h *DelayHook) ID() string { return "delay-hook" }

// Provides reports whether b is the OnDisconnect event.
func (h *DelayHook) Provides(b byte) bool {
	return b == OnDisconnect
}

// OnDisconnect sleeps for DisconnectDelay before returning.
func (h *DelayHook) OnDisconnect(cl *Client, err error, expire bool) {
	time.Sleep(h.DisconnectDelay)
}
// newServer builds a test server with a discarding logger, ReceiveMaximum and
// MaximumMessageExpiryInterval zeroed in the default capabilities, and an
// AllowHook registered.
func newServer() *Server {
	caps := NewDefaultServerCapabilities()
	caps.ReceiveMaximum = 0
	caps.MaximumMessageExpiryInterval = 0
	srv := New(&Options{Capabilities: caps, Logger: logger})
	// The AddHook error is deliberately ignored: this is a fixed test fixture.
	_ = srv.AddHook(new(AllowHook), nil)
	return srv
}
// newServerWithInlineClient builds the same test server as newServer, but with
// Options.InlineClient enabled.
func newServerWithInlineClient() *Server {
	caps := NewDefaultServerCapabilities()
	caps.ReceiveMaximum = 0
	caps.MaximumMessageExpiryInterval = 0
	srv := New(&Options{
		Capabilities: caps,
		InlineClient: true,
		Logger:       logger,
	})
	// The AddHook error is deliberately ignored: this is a fixed test fixture.
	_ = srv.AddHook(new(AllowHook), nil)
	return srv
}
// TestOptionsSetDefaults checks that ensureDefaults fills in the sys-topic
// interval and capabilities for zero-valued Options, however they are allocated.
func TestOptionsSetDefaults(t *testing.T) {
	literal := &Options{}
	literal.ensureDefaults()
	require.Equal(t, defaultSysTopicInterval, literal.SysTopicResendInterval)
	require.Equal(t, NewDefaultServerCapabilities(), literal.Capabilities)

	allocated := new(Options)
	allocated.ensureDefaults()
	require.Equal(t, defaultSysTopicInterval, allocated.SysTopicResendInterval)
}
// TestNew checks that New(nil) initialises every collaborator of the server,
// creates no inline client, and starts with no clients.
func TestNew(t *testing.T) {
	srv := New(nil)
	require.NotNil(t, srv)

	// Top-level fields are checked first, so the nested fields below can be
	// dereferenced safely.
	for _, field := range []any{
		srv.Clients, srv.Listeners, srv.Topics, srv.Info, srv.Log,
		srv.Options, srv.loop, srv.hooks, srv.done,
	} {
		require.NotNil(t, field)
	}
	for _, field := range []any{
		srv.loop.sysTopics, srv.loop.inflightExpiry, srv.loop.clientExpiry, srv.hooks.Log,
	} {
		require.NotNil(t, field)
	}

	require.Nil(t, srv.inlineClient)
	require.Equal(t, 0, srv.Clients.Len())
}
// TestNewWithInlineClient checks that enabling InlineClient creates the inline
// client and that it is counted in s.Clients.
func TestNewWithInlineClient(t *testing.T) {
	srv := New(&Options{InlineClient: true})
	require.NotNil(t, srv.inlineClient)
	require.Equal(t, 1, srv.Clients.Len())
}
// TestNewNilOpts checks that nil options still produce a server with Options set.
func TestNewNilOpts(t *testing.T) {
	srv := New(nil)
	require.NotNil(t, srv)
	require.NotNil(t, srv.Options)
}
// TestServerNewClient checks the fields NewClient initialises for a
// connection-backed (non-inline) client.
func TestServerNewClient(t *testing.T) {
	srv := New(nil)
	srv.Log = logger
	conn, _ := net.Pipe()

	client := srv.NewClient(conn, "testing", "test", false)
	require.NotNil(t, client)

	// Identity and transport.
	require.Equal(t, "test", client.ID)
	require.Equal(t, "testing", client.Net.Listener)
	require.False(t, client.Net.Inline)
	require.NotNil(t, client.Net.Conn)
	require.NotNil(t, client.Net.bconn)

	// Session-state containers.
	require.NotNil(t, client.State.Inflight.internal)
	require.NotNil(t, client.State.Subscriptions)
	require.NotNil(t, client.State.TopicAliases)

	// Defaults.
	require.Equal(t, defaultKeepalive, client.State.Keepalive)
	require.Equal(t, defaultClientProtocolVersion, client.Properties.ProtocolVersion)

	// The client shares the server's logger.
	require.NotNil(t, client.ops)
	require.Equal(t, srv.Log, client.ops.log)
}
// TestServerNewClientInline checks that the inline flag is recorded on the client.
func TestServerNewClientInline(t *testing.T) {
	client := New(nil).NewClient(nil, "testing", "test", true)
	require.True(t, client.Net.Inline)
}
// TestServerAddHook checks that a new server has no hooks and that AddHook
// registers exactly one.
func TestServerAddHook(t *testing.T) {
	srv := New(nil)
	srv.Log = logger
	require.NotNil(t, srv)

	require.Equal(t, int64(0), srv.hooks.Len())
	require.NoError(t, srv.AddHook(new(HookBase), nil))
	require.Equal(t, int64(1), srv.hooks.Len())
}
// TestServerAddListener checks that a listener can be added once and that a
// second listener with the same ID is rejected with ErrListenerIDExists.
func TestServerAddListener(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)

	require.NoError(t, srv.AddListener(listeners.NewMockListener("t1", ":1882")))

	dupErr := srv.AddListener(listeners.NewMockListener("t1", ":1882"))
	require.Error(t, dupErr)
	require.Equal(t, ErrListenerIDExists, dupErr)
}
// TestServerAddHooksFromConfig checks that a hook entry without a Config loads
// without error.
func TestServerAddHooksFromConfig(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)
	srv.Log = logger

	cfg := []HookLoadConfig{{Hook: new(modifiedHookBase)}}
	require.NoError(t, srv.AddHooksFromConfig(cfg))
}
// TestServerAddHooksFromConfigError checks that AddHooksFromConfig returns an
// error when the hook rejects the supplied Config.
func TestServerAddHooksFromConfigError(t *testing.T) {
	s := newServer()
	defer s.Close()
	require.NotNil(t, s)
	s.Log = logger
	hooks := []HookLoadConfig{
		// NOTE(review): presumably modifiedHookBase fails when handed a
		// non-nil config of this type — confirm against its definition.
		{Hook: new(modifiedHookBase), Config: map[string]any{}},
	}
	err := s.AddHooksFromConfig(hooks)
	require.Error(t, err)
}
// TestServerAddListenerInitFailure checks that AddListener returns an error
// when the mock listener is configured to fail (ErrListen).
func TestServerAddListenerInitFailure(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)

	mock := listeners.NewMockListener("t1", ":1882")
	mock.ErrListen = true
	require.Error(t, srv.AddListener(mock))
}
// TestServerAddListenersFromConfig checks that each known listener type is
// created with the configured address and that an unknown type is skipped
// without failing the whole call.
func TestServerAddListenersFromConfig(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)
	srv.Log = logger

	configs := []listeners.Config{
		{Type: listeners.TypeTCP, ID: "tcp", Address: ":1883"},
		{Type: listeners.TypeWS, ID: "ws", Address: ":1882"},
		{Type: listeners.TypeHealthCheck, ID: "health", Address: ":1881"},
		{Type: listeners.TypeSysInfo, ID: "info", Address: ":1880"},
		{Type: listeners.TypeUnix, ID: "unix", Address: "mochi.sock"},
		{Type: listeners.TypeMock, ID: "mock", Address: "0"},
		{Type: "unknown", ID: "unknown"},
	}
	require.NoError(t, srv.AddListenersFromConfig(configs))
	// Seven configs, six listeners: the "unknown" entry is not registered.
	require.Equal(t, 6, srv.Listeners.Len())

	for _, want := range []struct{ id, addr string }{
		{"tcp", "[::]:1883"},
		{"ws", ":1882"},
		{"health", ":1881"},
		{"info", ":1880"},
		{"unix", "mochi.sock"},
		{"mock", "0"},
	} {
		l, _ := srv.Listeners.Get(want.id)
		require.Equal(t, want.addr, l.Address())
	}
}
// TestServerAddListenersFromConfigError checks that an invalid TCP address
// produces an error and leaves no listener registered.
func TestServerAddListenersFromConfigError(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)
	srv.Log = logger

	bad := []listeners.Config{{Type: listeners.TypeTCP, ID: "tcp", Address: "x"}}
	require.Error(t, srv.AddListenersFromConfig(bad))
	require.Equal(t, 0, srv.Listeners.Len())
}
// TestServerServe checks that Serve starts a previously added listener.
func TestServerServe(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)

	require.NoError(t, srv.AddListener(listeners.NewMockListener("t1", ":1882")))
	require.NoError(t, srv.Serve())
	time.Sleep(time.Millisecond) // give the listener a moment to start serving

	require.Equal(t, 1, srv.Listeners.Len())
	l, found := srv.Listeners.Get("t1")
	require.True(t, found)
	require.True(t, l.(*listeners.MockListener).IsServing())
}
// TestServerServeFromConfig checks that Serve loads listeners and hooks
// supplied through Options and starts the listener.
func TestServerServeFromConfig(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)

	srv.Options.Listeners = []listeners.Config{{Type: listeners.TypeMock, ID: "mock", Address: "0"}}
	srv.Options.Hooks = []HookLoadConfig{{Hook: new(modifiedHookBase)}}
	require.NoError(t, srv.Serve())
	time.Sleep(time.Millisecond) // give the listener a moment to start serving

	require.Equal(t, 1, srv.Listeners.Len())
	l, found := srv.Listeners.Get("mock")
	require.True(t, found)
	require.True(t, l.(*listeners.MockListener).IsServing())
}
// TestServerServeFromConfigListenerError checks that Serve fails when Options
// contains a listener with an unusable TCP address.
func TestServerServeFromConfigListenerError(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)

	srv.Options.Listeners = []listeners.Config{{Type: listeners.TypeTCP, ID: "tcp", Address: "x"}}
	require.Error(t, srv.Serve())
}
// TestServerServeFromConfigHookError checks that Serve fails when a hook
// configured through Options rejects its Config.
func TestServerServeFromConfigHookError(t *testing.T) {
	s := newServer()
	defer s.Close()
	require.NotNil(t, s)
	s.Options.Hooks = []HookLoadConfig{
		// NOTE(review): presumably modifiedHookBase fails when handed a
		// non-nil config of this type — confirm against its definition.
		{Hook: new(modifiedHookBase), Config: map[string]any{}},
	}
	err := s.Serve()
	require.Error(t, err)
}
// TestServerServeReadStoreFailure checks that Serve returns an error when a
// registered hook is set up to fail (modifiedHookBase with failAt = 1).
func TestServerServeReadStoreFailure(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	require.NotNil(t, srv)
	require.NoError(t, srv.AddListener(listeners.NewMockListener("t1", ":1882")))

	// NOTE(review): presumably failAt = 1 makes the hook's stored-data reads
	// fail — confirm against modifiedHookBase.
	failing := new(modifiedHookBase)
	failing.failAt = 1
	require.NoError(t, srv.AddHook(failing, nil))
	require.Error(t, srv.Serve())
}
// TestServerEventLoop replaces every event-loop ticker with a 1ms ticker and
// runs the loop briefly, so each ticker fires while the loop is running.
func TestServerEventLoop(t *testing.T) {
	srv := newServer()
	defer srv.Close()

	const tick = time.Millisecond
	srv.loop.sysTopics = time.NewTicker(tick)
	srv.loop.inflightExpiry = time.NewTicker(tick)
	srv.loop.clientExpiry = time.NewTicker(tick)
	srv.loop.retainedExpiry = time.NewTicker(tick)
	srv.loop.willDelaySend = time.NewTicker(tick)

	go srv.eventLoop()
	time.Sleep(3 * tick)
}
// TestServerReadConnectionPacket checks that a well-formed MQTT 3.1.1 CONNECT
// written to the client's peer pipe is decoded by readConnectionPacket.
func TestServerReadConnectionPacket(t *testing.T) {
	s := newServer()
	defer s.Close()
	cl, r, _ := newTestClient()
	s.Clients.Add(cl)

	// require (t.FailNow) may only be called from the test goroutine. The
	// reader therefore returns both the packet and the error, and the checks
	// happen below. A failed read can no longer strand the receive on o.
	type result struct {
		pk  packets.Packet
		err error
	}
	o := make(chan result, 1)
	go func() {
		pk, err := s.readConnectionPacket(cl)
		o <- result{pk: pk, err: err}
	}()
	go func() {
		_, _ = r.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
		_ = r.Close()
	}()

	res := <-o
	require.NoError(t, res.err)
	require.Equal(t, *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet, res.pk)
}
// TestServerReadConnectionPacketBadFixedHeader checks that a CONNECT with a
// malformed fixed header is rejected with ErrMalformedVariableByteInteger.
func TestServerReadConnectionPacketBadFixedHeader(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	client, peer, _ := newTestClient()
	srv.Clients.Add(client)

	errs := make(chan error)
	go func() {
		_, readErr := srv.readConnectionPacket(client)
		errs <- readErr
	}()
	go func() {
		_, _ = peer.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalFixedHeader).RawBytes)
		_ = peer.Close()
	}()

	got := <-errs
	require.Error(t, got)
	require.Equal(t, packets.ErrMalformedVariableByteInteger, got)
}
// TestServerReadConnectionPacketBadPacketType checks that a first packet other
// than CONNECT (here a CONNACK) is rejected.
func TestServerReadConnectionPacketBadPacketType(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	client, peer, _ := newTestClient()
	srv.Clients.Add(client)

	go func() {
		_, _ = peer.Write(packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes)
		_ = peer.Close()
	}()

	_, got := srv.readConnectionPacket(client)
	require.Error(t, got)
	require.Equal(t, packets.ErrProtocolViolationRequireFirstConnect, got)
}
// TestServerReadConnectionPacketBadPacket checks that a CONNECT with a
// malformed protocol name is rejected with an error wrapping
// ErrMalformedProtocolName.
func TestServerReadConnectionPacketBadPacket(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	client, peer, _ := newTestClient()
	srv.Clients.Add(client)

	go func() {
		_, _ = peer.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalProtocolName).RawBytes)
		_ = peer.Close()
	}()

	_, got := srv.readConnectionPacket(client)
	require.Error(t, got)
	require.ErrorIs(t, got, packets.ErrMalformedProtocolName)
}
// TestEstablishConnection sends a clean-session CONNECT followed by a
// DISCONNECT through EstablishConnection. It checks the CONNACK the client
// receives and that the clean session is removed from s.Clients afterwards.
func TestEstablishConnection(t *testing.T) {
	s := newServer()
	defer s.Close()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	// Receive the connack. require (t.FailNow) must not run off the test
	// goroutine, so read errors are reported with t.Errorf. The channel is
	// buffered so this goroutine can finish even if the test stops early.
	recv := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(w)
		if err != nil {
			t.Errorf("reading server output: %v", err)
		}
		recv <- buf
	}()
	err := <-o
	require.NoError(t, err)
	// Todo:
	// s.Clients is already empty here. Is it necessary to check v.StopCause()?
	// for _, v := range s.Clients.GetAll() {
	// require.ErrorIs(t, v.StopCause(), packets.CodeDisconnect) // true error is disconnect
	// }
	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
	_ = w.Close()
	_ = r.Close()
	// client must be deleted on session close if Clean = true
	_, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet.Connect.ClientIdentifier)
	require.False(t, ok)
}
// TestEstablishConnectionAckFailure sends a CONNECT and then closes the client
// end immediately. It checks that EstablishConnection fails with an error
// wrapping io.ErrClosedPipe.
func TestEstablishConnectionAckFailure(t *testing.T) {
	srv := newServer()
	defer srv.Close()
	serverSide, clientSide := net.Pipe()

	result := make(chan error)
	go func() { result <- srv.EstablishConnection("tcp", serverSide) }()
	go func() {
		_, _ = clientSide.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
		_ = clientSide.Close()
	}()

	got := <-result
	require.Error(t, got)
	require.ErrorIs(t, got, io.ErrClosedPipe)
	_ = serverSide.Close()
}
// TestEstablishConnectionReadError sends two CONNECT packets on one connection.
// It checks that the connection fails, that the client's stop cause is
// ErrProtocolViolationSecondConnect, and that the client receives a CONNACK
// followed by the second-connect DISCONNECT.
func TestEstablishConnectionReadError(t *testing.T) {
	s := newServer()
	defer s.Close()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes) // second connect error
	}()
	// Receive the connack. require (t.FailNow) must not run off the test
	// goroutine, so read errors are reported with t.Errorf. The channel is
	// buffered so this goroutine can finish even if the test stops early.
	recv := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(w)
		if err != nil {
			t.Errorf("reading server output: %v", err)
		}
		recv <- buf
	}()
	err := <-o
	require.Error(t, err)
	// Retrieve the client corresponding to the Client Identifier.
	retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	require.ErrorIs(t, retrievedCl.StopCause(), packets.ErrProtocolViolationSecondConnect) // true error is disconnect
	ret := <-recv
	// Build the expected stream in a fresh slice so append can never write
	// into the backing array of the shared packets.TPacketData fixture.
	want := append([]byte{}, packets.TPacketData[packets.Connack].Get(packets.TConnackMinCleanMqtt5).RawBytes...)
	want = append(want, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectSecondConnect).RawBytes...)
	require.Equal(t, want, ret)
	_ = w.Close()
	_ = r.Close()
}
// TestEstablishConnectionInheritExisting connects with the client ID of an
// existing session. It checks that:
//   - the old connection receives a takeover DISCONNECT;
//   - the new connection receives a session-present CONNACK followed by the
//     resent QoS 1 inflight publish (the TPublishQos1Dup fixture);
//   - subscriptions move to the new client and are cleared on the old one.
func TestEstablishConnectionInheritExisting(t *testing.T) {
	s := newServer()
	defer s.Close()
	cl, r0, _ := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.Properties.Username = []byte("mochi")
	cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
	cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
	cl.State.Inflight.Set(*packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
	s.Clients.Add(cl)
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		err := s.EstablishConnection("tcp", r)
		o <- err
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
		time.Sleep(time.Millisecond) // we want to receive the queued inflight, so we need to wait a moment before sending the disconnect.
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	// Reader goroutines report with t.Errorf: require (t.FailNow) is only valid
	// on the test goroutine. The buffered channels let them finish even if the
	// test stops before receiving.
	// receive the disconnect session takeover
	takeover := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(r0)
		if err != nil {
			t.Errorf("reading old client output: %v", err)
		}
		takeover <- buf
	}()
	// receive the connack
	recv := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(w)
		if err != nil {
			t.Errorf("reading new client output: %v", err)
		}
		recv <- buf
	}()
	err := <-o
	require.NoError(t, err)
	// Retrieve the client corresponding to the Client Identifier.
	retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
	// Build the expected stream in a fresh slice so append can never write
	// into the backing array of the shared packets.TPacketData fixture.
	connackPlusPacket := append([]byte{}, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedSessionExists).RawBytes...)
	connackPlusPacket = append(connackPlusPacket, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Dup).RawBytes...)
	require.Equal(t, connackPlusPacket, <-recv)
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectTakeover).RawBytes, <-takeover)
	time.Sleep(time.Microsecond * 100)
	_ = w.Close()
	_ = r.Close()
	clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	require.NotEmpty(t, clw.State.Subscriptions)
	// Prevent sequential takeover memory-bloom.
	require.Empty(t, cl.State.Subscriptions.GetAll())
}
// See https://github.com/mochi-mqtt/server/issues/173
//
// TestEstablishConnectionInheritExistingTrueTakeover connects client "zen"
// with a clean session and adds two subscriptions to it. It then connects a
// second client with the same ID while DelayHook slows OnDisconnect by 200ms.
// It checks that the first connection fails with io.ErrClosedPipe, the second
// client carries the new username and the subscriptions, and the first
// client's subscriptions are emptied.
func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {
	s := newServer()
	d := new(DelayHook)
	d.DisconnectDelay = time.Millisecond * 200
	_ = s.AddHook(d, nil)
	defer s.Close()
	// Clean session, 0 session expiry interval
	cl1RawBytes := []byte{
		packets.Connect << 4, 21, // Fixed header
		0, 4, // Protocol Name - MSB+LSB
		'M', 'Q', 'T', 'T', // Protocol Name
		5,      // Protocol Version
		1 << 1, // Packet Flags
		0, 30, // Keepalive
		5,                // Properties length
		17, 0, 0, 0, 0, // Session Expiry Interval (17)
		0, 3, // Client ID - MSB+LSB
		'z', 'e', 'n', // Client ID "zen"
	}
	// Make first connection
	r1, w1 := net.Pipe()
	o1 := make(chan error)
	go func() {
		err := s.EstablishConnection("tcp", r1)
		o1 <- err
	}()
	go func() {
		_, _ = w1.Write(cl1RawBytes)
	}()
	// Drain the first client's output. Nothing receives from recv, so it is
	// buffered to let this goroutine exit instead of blocking forever. Errors
	// go through t.Errorf because require (t.FailNow) is test-goroutine only.
	recv := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(w1)
		if err != nil {
			t.Errorf("reading first client output: %v", err)
		}
		recv <- buf
	}()
	// Get the first client pointer
	time.Sleep(time.Millisecond * 50)
	cl1, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	cl1.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
	cl1.State.Subscriptions.Add("d/e/f", packets.Subscription{Filter: "d/e/f", Qos: 0})
	time.Sleep(time.Millisecond * 50)
	// Make the second connection
	r2, w2 := net.Pipe()
	o2 := make(chan error)
	go func() {
		err := s.EstablishConnection("tcp", r2)
		o2 <- err
	}()
	go func() {
		// Copy before editing: RawBytes[:] would alias the shared
		// packets.TPacketData fixture, and changing it in place would leak the
		// altered username into every other test that uses TConnectUserPass.
		x := bytes.Clone(packets.TPacketData[packets.Connect].Get(packets.TConnectUserPass).RawBytes)
		x[19] = '.' // differentiate username bytes in debugging
		_, _ = w2.Write(x)
	}()
	// Drain the second client's output; buffered for the same reason as recv.
	recv2 := make(chan []byte, 1)
	go func() {
		buf, err := io.ReadAll(w2)
		if err != nil {
			t.Errorf("reading second client output: %v", err)
		}
		recv2 <- buf
	}()
	// Capture first Client pointer
	clp1, ok := s.Clients.Get("zen")
	require.True(t, ok)
	require.Empty(t, clp1.Properties.Username)
	require.NotEmpty(t, clp1.State.Subscriptions.GetAll())
	err1 := <-o1
	require.Error(t, err1)
	require.ErrorIs(t, err1, io.ErrClosedPipe)
	// Capture second Client pointer
	clp2, ok := s.Clients.Get("zen")
	require.True(t, ok)
	require.Equal(t, []byte(".ochi"), clp2.Properties.Username)
	require.NotEmpty(t, clp2.State.Subscriptions.GetAll())
	require.Empty(t, clp1.State.Subscriptions.GetAll())
	_, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	require.NoError(t, <-o2)
}
// TestEstablishConnectionResentPendingInflightsError tests reconnection when
// the existing client's inflight store holds a packet with no packet type.
// It registers such a client, reconnects with the same client ID, and expects
// EstablishConnection to fail with packets.ErrNoValidPacketAvailable.
// That error is presumably raised while resending the untyped inflight packet.
func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
	s := newServer()
	defer s.Close()
	n := time.Now().Unix()
	cl, r0, _ := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
	cl.State.Inflight = NewInflights()
	cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: n - 2}) // no packet type
	s.Clients.Add(cl)
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
	}()
	// Drain both the existing client's pipe and the new connection's pipe.
	// net.Pipe is unbuffered, so server writes to either side would otherwise
	// block.
	go func() {
		_, err := io.ReadAll(r0)
		require.NoError(t, err)
	}()
	go func() {
		_, err := io.ReadAll(w)
		require.NoError(t, err)
	}()
	err := <-o
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrNoValidPacketAvailable)
}
// TestEstablishConnectionInheritExistingClean reconnects over an existing
// client that has Clean set and one subscription. It checks that:
//   - the new connection receives the TConnackAcceptedNoSession CONNACK;
//   - the old client is sent a DISCONNECT;
//   - the client stored under that ID ends up with no subscriptions.
func TestEstablishConnectionInheritExistingClean(t *testing.T) {
	s := newServer()
	defer s.Close()
	cl, r0, _ := newTestClient()
	cl.ID = packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier
	cl.Properties.Clean = true
	cl.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
	s.Clients.Add(cl)
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	// receive the disconnect sent to the existing (taken-over) client
	takeover := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r0)
		require.NoError(t, err)
		takeover <- buf
	}()
	// receive the connack sent to the new connection
	recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(w)
		require.NoError(t, err)
		recv <- buf
	}()
	err := <-o
	require.NoError(t, err)
	// Retrieve the client corresponding to the Client Identifier.
	retrievedCl, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	require.ErrorIs(t, retrievedCl.StopCause(), packets.CodeDisconnect) // true error is disconnect
	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)
	_ = w.Close()
	_ = r.Close()
	// The clean session must not have carried the old subscription over.
	clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
	require.True(t, ok)
	require.Equal(t, 0, clw.State.Subscriptions.Len())
}
// TestEstablishConnectionBadAuthentication builds a server with only a Logger
// option; no hooks are added here. It checks that a clean CONNECT is refused
// with packets.ErrBadUsernameOrPassword and that the client receives the
// matching TConnackBadUsernamePasswordNoSession CONNACK.
func TestEstablishConnectionBadAuthentication(t *testing.T) {
	s := New(&Options{
		Logger: logger,
	})
	defer s.Close()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	// receive the connack
	recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(w)
		require.NoError(t, err)
		recv <- buf
	}()
	err := <-o
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrBadUsernameOrPassword)
	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackBadUsernamePasswordNoSession).RawBytes, <-recv)
	_ = w.Close()
	_ = r.Close()
}
// TestEstablishConnectionBadAuthenticationAckFailure is like
// TestEstablishConnectionBadAuthentication, except the client closes its end
// immediately after sending CONNECT. EstablishConnection must then fail with
// io.ErrClosedPipe, presumably from writing the rejection CONNACK to the
// closed pipe.
func TestEstablishConnectionBadAuthenticationAckFailure(t *testing.T) {
	s := New(&Options{
		Logger: logger,
	})
	defer s.Close()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
		_ = w.Close() // nobody is left to read the server's reply
	}()
	err := <-o
	require.Error(t, err)
	require.ErrorIs(t, err, io.ErrClosedPipe)
	_ = r.Close()
}
// TestServerEstablishConnectionInvalidConnect sends a CONNECT with the
// reserved bit set. It checks that EstablishConnection fails with
// packets.ErrProtocolViolationReservedBit and that the client receives the
// TConnackProtocolViolationNoSession CONNACK.
func TestServerEstablishConnectionInvalidConnect(t *testing.T) {
	s := newServer()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()

	// receive the connack
	recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(w)
		require.NoError(t, err)
		recv <- buf
	}()

	err := <-o
	require.Error(t, err)
	// require.ErrorIs takes (err, target). The swapped order only worked while err
	// was exactly the sentinel, and would break if the error were ever wrapped.
	require.ErrorIs(t, err, packets.ErrProtocolViolationReservedBit)
	require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackProtocolViolationNoSession).RawBytes, <-recv)
	_ = r.Close()
}
// TestServerEstablishConnectionZeroByteUsernameIsValid checks that a CONNECT
// carrying a zero-byte username is accepted. EstablishConnection must return
// no error after the client sends CONNECT followed by DISCONNECT.
//
// See https://github.com/mochi-mqtt/server/issues/178
func TestServerEstablishConnectionZeroByteUsernameIsValid(t *testing.T) {
	s := newServer()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectZeroByteUsername).RawBytes)
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	// Drain the (successful) connack. net.Pipe is unbuffered, so an unread
	// server write would block.
	go func() {
		_, err := io.ReadAll(w)
		require.NoError(t, err)
	}()
	err := <-o
	require.NoError(t, err)
	_ = r.Close()
}
// TestServerEstablishConnectionInvalidConnectAckFailure sends a CONNECT with
// the reserved bit set and then closes the client end at once. The error
// returned is therefore io.ErrClosedPipe rather than the reserved-bit
// violation, presumably because the error CONNACK cannot be written.
func TestServerEstablishConnectionInvalidConnectAckFailure(t *testing.T) {
	s := newServer()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectMalReservedBit).RawBytes)
		_ = w.Close()
	}()
	err := <-o
	require.Error(t, err)
	require.ErrorIs(t, err, io.ErrClosedPipe)
	_ = r.Close()
}
// TestServerEstablishConnectionBadPacket checks that a connection whose first
// packet is not CONNECT is rejected with
// packets.ErrProtocolViolationRequireFirstConnect.
func TestServerEstablishConnectionBadPacket(t *testing.T) {
	s := newServer()
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		// Open with a DISCONNECT instead of a CONNECT. Don't look up a Connack
		// case id (e.g. TConnackBadProtocolVersion) in the Connect table here:
		// no such Connect case exists, so that write carries no bytes and only
		// hides what the test actually sends.
		_, _ = w.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
	}()
	err := <-o
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationRequireFirstConnect)
	_ = r.Close()
}
// TestServerEstablishConnectionOnConnectError registers a modifiedHookBase
// configured to fail (fail = true). It checks that the hook's errTestHook is
// returned from EstablishConnection for an otherwise valid clean CONNECT.
func TestServerEstablishConnectionOnConnectError(t *testing.T) {
	s := newServer()
	hook := new(modifiedHookBase)
	hook.fail = true
	err := s.AddHook(hook, nil)
	require.NoError(t, err)
	r, w := net.Pipe()
	o := make(chan error)
	go func() {
		o <- s.EstablishConnection("tcp", r)
	}()
	go func() {
		_, _ = w.Write(packets.TPacketData[packets.Connect].Get(packets.TConnectClean).RawBytes)
	}()
	err = <-o
	require.Error(t, err)
	require.ErrorIs(t, err, errTestHook)
	_ = r.Close()
}
// TestServerSendConnack checks the bytes of a successful MQTT v5 CONNACK that
// carries an assigned client ID while the server's MaximumQos is 1.
func TestServerSendConnack(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	srv.Options.Capabilities.MaximumQos = 1
	client.Properties.Props = packets.Properties{AssignedClientID: "mochi"}

	go func() {
		sendErr := srv.SendConnack(client, packets.CodeSuccess, true, nil)
		require.NoError(t, sendErr)
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	want := packets.TPacketData[packets.Connack].Get(packets.TConnackMinMqtt5).RawBytes
	require.Equal(t, want, got)
}
// TestServerSendConnackFailureReason checks the MQTT v5 CONNACK bytes written
// when the reason code is packets.ErrUnspecifiedError.
func TestServerSendConnackFailureReason(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		require.NoError(t, srv.SendConnack(client, packets.ErrUnspecifiedError, true, nil))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	want := packets.TPacketData[packets.Connack].Get(packets.TConnackInvalidMinMqtt5).RawBytes
	require.Equal(t, want, got)
}
// TestServerSendConnackWithServerKeepalive checks the MQTT v5 CONNACK bytes
// when the client's state has ServerKeepalive set and a keepalive of 10.
func TestServerSendConnackWithServerKeepalive(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Keepalive = 10
	client.State.ServerKeepalive = true

	go func() {
		sendErr := srv.SendConnack(client, packets.CodeSuccess, true, nil)
		require.NoError(t, sendErr)
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	want := packets.TPacketData[packets.Connack].Get(packets.TConnackServerKeepalive).RawBytes
	require.Equal(t, want, got)
}
// TestServerValidateConnect runs validateConnect against client, capability
// and packet combinations that must each be rejected. It checks the returned
// error against the expected packets.Code.
func TestServerValidateConnect(t *testing.T) {
	base := *packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt5).Packet

	reservedBitSet := base
	reservedBitSet.ReservedBit = 1

	noClientID := base
	noClientID.Connect.Clean = false
	noClientID.Connect.ClientIdentifier = ""

	cases := []struct {
		desc         string
		client       *Client
		capabilities Capabilities
		packet       packets.Packet
		expect       packets.Code
	}{
		{
			desc:         "unsupported protocol version",
			client:       &Client{Properties: ClientProperties{ProtocolVersion: 3}},
			capabilities: Capabilities{MinimumProtocolVersion: 4},
			packet:       base,
			expect:       packets.ErrUnsupportedProtocolVersion,
		},
		{
			desc:         "will qos not supported",
			client:       &Client{Properties: ClientProperties{Will: Will{Qos: 2}}},
			capabilities: Capabilities{MaximumQos: 1},
			packet:       base,
			expect:       packets.ErrQosNotSupported,
		},
		{
			desc:         "retain not supported",
			client:       &Client{Properties: ClientProperties{Will: Will{Retain: true}}},
			capabilities: Capabilities{RetainAvailable: 0},
			packet:       base,
			expect:       packets.ErrRetainNotSupported,
		},
		{
			desc:         "invalid packet validate",
			client:       &Client{Properties: ClientProperties{Will: Will{Retain: true}}},
			capabilities: Capabilities{RetainAvailable: 0},
			packet:       reservedBitSet,
			expect:       packets.ErrProtocolViolationReservedBit,
		},
		{
			desc:         "mqtt3 clean no client id ",
			client:       &Client{Properties: ClientProperties{ProtocolVersion: 3}},
			capabilities: Capabilities{},
			packet:       noClientID,
			expect:       packets.ErrUnspecifiedError,
		},
	}

	srv := newServer()
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			// Point the server at a per-case copy of the capabilities.
			caps := tc.capabilities
			srv.Options.Capabilities = &caps

			err := srv.validateConnect(tc.client, tc.packet)
			require.Error(t, err)
			require.ErrorIs(t, err, tc.expect)
		})
	}
}
// TestServerSendConnackAdjustedExpiryInterval covers a client that requests a
// 300 second session expiry interval while the server maximum is 120. It
// checks the CONNACK bytes against the TConnackAcceptedAdjustedExpiryInterval
// fixture.
func TestServerSendConnackAdjustedExpiryInterval(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.Properties.Props.SessionExpiryInterval = uint32(300)
	srv.Options.Capabilities.MaximumSessionExpiryInterval = 120

	go func() {
		sendErr := srv.SendConnack(client, packets.CodeSuccess, false, nil)
		require.NoError(t, sendErr)
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	want := packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedAdjustedExpiryInterval).RawBytes
	require.Equal(t, want, got)
}
// TestInheritClientSession checks inheritClientSession against a stored
// client "mochi" that has two inflight packets and one subscription:
//   - a non-clean CONNECT for "mochi" copies that state into the new client
//     and returns true;
//   - a clean CONNECT returns false and leaves the new client empty.
func TestInheritClientSession(t *testing.T) {
	srv := newServer()
	now := time.Now().Unix()

	prior, _, _ := newTestClient()
	prior.Net.Conn = nil
	prior.ID = "mochi"
	prior.State.Subscriptions.Add("a/b/c", packets.Subscription{Filter: "a/b/c", Qos: 1})
	prior.State.Inflight = NewInflights()
	prior.State.Inflight.Set(packets.Packet{PacketID: 1, Created: now - 1})
	prior.State.Inflight.Set(packets.Packet{PacketID: 2, Created: now - 2})
	srv.Clients.Add(prior)

	connectAs := func(clean bool) packets.Packet {
		return packets.Packet{Connect: packets.ConnectParams{ClientIdentifier: "mochi", Clean: clean}}
	}

	// Non-clean: inherit existing client properties.
	fresh, _, _ := newTestClient()
	fresh.Properties.ProtocolVersion = 5
	require.Equal(t, 0, fresh.State.Inflight.Len())
	require.Equal(t, 0, fresh.State.Subscriptions.Len())
	inherited := srv.inheritClientSession(connectAs(false), fresh)
	require.True(t, inherited)
	require.Equal(t, 2, fresh.State.Inflight.Len())
	require.Equal(t, 1, fresh.State.Subscriptions.Len())

	// Clean: nothing is carried over.
	fresh, _, _ = newTestClient()
	fresh.Properties.ProtocolVersion = 5
	inherited = srv.inheritClientSession(connectAs(true), fresh)
	require.False(t, inherited)
	require.Equal(t, 0, fresh.State.Inflight.Len())
	require.Equal(t, 0, fresh.State.Subscriptions.Len())
}
// TestServerUnsubscribeClient checks that UnsubscribeClient removes a client's
// "a/b/c" subscription from the server's topic index.
func TestServerUnsubscribeClient(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()

	sub := packets.Subscription{Filter: "a/b/c", Qos: 1}
	client.State.Subscriptions.Add(sub.Filter, sub)
	srv.Topics.Subscribe(client.ID, sub)

	before := srv.Topics.Subscribers(sub.Filter)
	require.Equal(t, 1, len(before.Subscriptions))

	srv.UnsubscribeClient(client)

	after := srv.Topics.Subscribers(sub.Filter)
	require.Equal(t, 0, len(after.Subscriptions))
}
// TestServerProcessPacketFailure checks that processPacket returns an error
// for a zero-value packet.
func TestServerProcessPacketFailure(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	var empty packets.Packet
	require.Error(t, srv.processPacket(client, empty))
}
// TestServerProcessPacketConnect checks that processPacket returns an error
// when handed a CONNECT packet.
func TestServerProcessPacketConnect(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	connect := *packets.TPacketData[packets.Connect].Get(packets.TConnectClean).Packet
	require.Error(t, srv.processPacket(client, connect))
}
// TestServerProcessPacketPingreq checks that a PINGREQ is answered with the
// TPingresp bytes.
func TestServerProcessPacketPingreq(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	ping := *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet

	go func() {
		require.NoError(t, srv.processPacket(client, ping))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Pingresp].Get(packets.TPingresp).RawBytes, got)
}
// TestServerProcessPacketPingreqError stops the client with
// packets.CodeDisconnect before a PINGREQ is processed. It checks that
// processPacket errors and that the stop cause is unchanged.
func TestServerProcessPacketPingreqError(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Stop(packets.CodeDisconnect)

	ping := *packets.TPacketData[packets.Pingreq].Get(packets.TPingreq).Packet
	require.Error(t, srv.processPacket(client, ping))
	require.ErrorIs(t, client.StopCause(), packets.CodeDisconnect)
}
// TestServerProcessPacketPublishInvalid checks that the
// TPublishInvalidQosMustPacketID fixture is rejected with
// packets.ErrProtocolViolationNoPacketID.
func TestServerProcessPacketPublishInvalid(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	invalid := *packets.TPacketData[packets.Publish].Get(packets.TPublishInvalidQosMustPacketID).Packet

	err := srv.processPacket(client, invalid)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
}
// TestInjectPacketPublishAndReceive injects a basic PUBLISH on behalf of an
// inline "sender" client. It checks that the "receiver" client, subscribed to
// "a/b/c", is written the same encoded bytes.
func TestInjectPacketPublishAndReceive(t *testing.T) {
	s := newServer()
	_ = s.Serve()
	defer s.Close()
	sender, _, w1 := newTestClient()
	sender.Net.Inline = true
	sender.ID = "sender"
	s.Clients.Add(sender)
	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	// Collect everything written to the receiver until its pipe is closed.
	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()
	// Inject, then close both pipes. The short sleep presumably gives delivery
	// to the receiver time to finish before w2 is closed.
	go func() {
		err := s.InjectPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		require.NoError(t, err)
		_ = w1.Close()
		time.Sleep(time.Millisecond * 10)
		_ = w2.Close()
	}()
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
// TestServerPublishAndReceive uses Server.Publish (on a server built by
// newServerWithInlineClient) to send the TPublishBasic topic and payload. It
// checks that the "receiver" client, subscribed to "a/b/c", is written the
// TPublishBasic bytes.
func TestServerPublishAndReceive(t *testing.T) {
	s := newServerWithInlineClient()
	_ = s.Serve()
	defer s.Close()
	sender, _, w1 := newTestClient()
	sender.Net.Inline = true
	sender.ID = "sender"
	s.Clients.Add(sender)
	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	// Collect everything written to the receiver until its pipe is closed.
	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()
	// Publish, then close both pipes. The short sleep presumably gives delivery
	// to the receiver time to finish before w2 is closed.
	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		err := s.Publish(pkx.TopicName, pkx.Payload, pkx.FixedHeader.Retain, pkx.FixedHeader.Qos)
		require.NoError(t, err)
		_ = w1.Close()
		time.Sleep(time.Millisecond * 10)
		_ = w2.Close()
	}()
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
}
// TestServerPublishNoInlineClient checks that Server.Publish fails with
// ErrInlineClientNotEnabled on a server created by newServer.
func TestServerPublishNoInlineClient(t *testing.T) {
	srv := newServer()
	msg := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet

	publishErr := srv.Publish(msg.TopicName, msg.Payload, msg.FixedHeader.Retain, msg.FixedHeader.Qos)
	require.Error(t, publishErr)
	require.ErrorIs(t, publishErr, ErrInlineClientNotEnabled)
}
// TestInjectPacketError checks that InjectPacket returns an error for a
// SUBSCRIBE whose filter list is empty.
func TestInjectPacketError(t *testing.T) {
	srv := newServer()
	defer srv.Close()

	client, _, _ := newTestClient()
	client.Net.Inline = true

	noFilters := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
	noFilters.Filters = packets.Subscriptions{}
	require.Error(t, srv.InjectPacket(client, noFilters))
}
// TestInjectPacketPublishInvalidTopic checks that an inline client may inject
// a publish to "$SYS/test" without error.
func TestInjectPacketPublishInvalidTopic(t *testing.T) {
	srv := newServer()
	defer srv.Close()

	client, _, _ := newTestClient()
	client.Net.Inline = true

	sysPublish := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	sysPublish.TopicName = "$SYS/test"
	require.NoError(t, srv.InjectPacket(client, sysPublish)) // bypass topic validity and acl checks
}
// TestServerProcessPacketPublishAndReceive processes a retained PUBLISH
// (TPublishRetain) from "sender". It checks two things:
//   - "receiver", subscribed to "a/b/c", is written the TPublishBasic bytes
//     (presumably the retain flag is not forwarded to this subscriber);
//   - the server now holds one retained message for "a/b/c".
func TestServerProcessPacketPublishAndReceive(t *testing.T) {
	s := newServer()
	_ = s.Serve()
	defer s.Close()
	sender, _, w1 := newTestClient()
	sender.ID = "sender"
	s.Clients.Add(sender)
	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
	// Collect everything written to the receiver until its pipe is closed.
	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()
	go func() {
		err := s.processPacket(sender, *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
		require.NoError(t, err)
		time.Sleep(time.Millisecond * 10)
		_ = w1.Close()
		_ = w2.Close()
	}()
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-receiverBuf)
	require.Equal(t, 1, len(s.Topics.Messages("a/b/c")))
}
// TestServerBuildAck checks that buildAck sets the packet type, QoS and reason
// code, and carries the given properties through unchanged.
func TestServerBuildAck(t *testing.T) {
	srv := newServer()
	props := packets.Properties{
		User: []packets.UserProperty{{Key: "hello", Val: "世界"}},
	}

	ack := srv.buildAck(7, packets.Puback, 1, props, packets.CodeGrantedQos1)

	require.Equal(t, packets.Puback, ack.FixedHeader.Type)
	require.Equal(t, uint8(1), ack.FixedHeader.Qos)
	require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
	require.Equal(t, props, ack.Properties)
}
// TestServerBuildAckError checks that buildAck with an error reason adds that
// error's Reason as the ReasonString on top of the given properties.
func TestServerBuildAckError(t *testing.T) {
	srv := newServer()
	props := packets.Properties{
		User: []packets.UserProperty{{Key: "hello", Val: "世界"}},
	}

	ack := srv.buildAck(7, packets.Puback, 1, props, packets.ErrMalformedPacket)

	require.Equal(t, packets.Puback, ack.FixedHeader.Type)
	require.Equal(t, uint8(1), ack.FixedHeader.Qos)
	require.Equal(t, packets.ErrMalformedPacket.Code, ack.ReasonCode)

	want := props
	want.ReasonString = packets.ErrMalformedPacket.Reason
	require.Equal(t, want, ack.Properties)
}
// TestServerBuildAckPahoCompatibility checks that, with
// Compatibilities.NoInheritedPropertiesOnAck enabled, buildAck drops the given
// properties and returns an empty packets.Properties.
func TestServerBuildAckPahoCompatibility(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true
	props := packets.Properties{
		User: []packets.UserProperty{{Key: "hello", Val: "世界"}},
	}

	ack := srv.buildAck(7, packets.Puback, 1, props, packets.CodeGrantedQos1)

	require.Equal(t, packets.Puback, ack.FixedHeader.Type)
	require.Equal(t, uint8(1), ack.FixedHeader.Qos)
	require.Equal(t, packets.CodeGrantedQos1.Code, ack.ReasonCode)
	require.Equal(t, packets.Properties{}, ack.Properties)
}
// TestServerProcessPacketAndNextImmediate queues a QoS 1 publish with
// Expiry = -1 in the client's inflight store, then processes an incoming
// basic publish. It checks that:
//   - the queued packet's bytes are written to the client;
//   - the server inflight counter drops from 1 to 0;
//   - the client's sendQuota drops from 5 to 4.
// (Expiry = -1 presumably marks the packet for immediate resend.)
func TestServerProcessPacketAndNextImmediate(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()

	queued := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	queued.Expiry = -1
	client.State.Inflight.Set(queued)
	atomic.StoreInt64(&srv.Info.Inflight, 1)
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.Inflight))
	require.Equal(t, int32(5), client.State.Inflight.sendQuota)

	incoming := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	go func() {
		require.NoError(t, srv.processPacket(client, incoming))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, got)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Inflight))
	require.Equal(t, int32(4), client.State.Inflight.sendQuota)
}
// TestServerProcessPublishAckFailure closes the client's pipe before a QoS 2
// publish is processed. It checks that processPublish surfaces
// io.ErrClosedPipe.
func TestServerProcessPublishAckFailure(t *testing.T) {
	srv := newServer()
	_ = srv.Serve()
	defer srv.Close()

	client, _, writer := newTestClient()
	srv.Clients.Add(client)
	_ = writer.Close()

	qos2 := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet
	publishErr := srv.processPublish(client, qos2)
	require.Error(t, publishErr)
	require.ErrorIs(t, publishErr, io.ErrClosedPipe)
}
// TestServerProcessPublishOnPublishAckErrorRWError installs a hook that fails
// with packets.ErrUnspecifiedError and closes the client's pipe. It checks
// that processing a QoS 1 publish from an MQTT v5 client returns
// io.ErrClosedPipe.
func TestServerProcessPublishOnPublishAckErrorRWError(t *testing.T) {
	srv := newServer()
	failing := new(modifiedHookBase)
	failing.fail = true
	failing.err = packets.ErrUnspecifiedError
	require.NoError(t, srv.AddHook(failing, nil))

	client, _, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	srv.Clients.Add(client)
	_ = writer.Close()

	err := srv.processPublish(client, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
	require.Error(t, err)
	require.ErrorIs(t, err, io.ErrClosedPipe)
}
// TestServerProcessPublishOnPublishAckErrorContinue installs a hook that fails
// with packets.ErrPayloadFormatInvalid. It checks that processing a QoS 1
// publish still returns nil and that the client is written the
// TPubackUnexpectedError bytes.
func TestServerProcessPublishOnPublishAckErrorContinue(t *testing.T) {
	srv := newServer()
	failing := new(modifiedHookBase)
	failing.fail = true
	failing.err = packets.ErrPayloadFormatInvalid
	require.NoError(t, srv.AddHook(failing, nil))
	_ = srv.Serve()
	defer srv.Close()

	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	srv.Clients.Add(client)

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPubackUnexpectedError).RawBytes, got)
}
// TestServerProcessPublishOnPublishPkIgnore installs a hook that fails with
// packets.CodeSuccessIgnore, then processes a QoS 1 publish. It checks that:
//   - the sender still receives a normal TPuback;
//   - the subscribed "receiver" is written nothing;
//   - no retained message is stored for "a/b/c".
func TestServerProcessPublishOnPublishPkIgnore(t *testing.T) {
	s := newServer()
	hook := new(modifiedHookBase)
	hook.fail = true
	hook.err = packets.CodeSuccessIgnore
	err := s.AddHook(hook, nil)
	require.NoError(t, err)
	_ = s.Serve()
	defer s.Close()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	receiver, r2, w2 := newTestClient()
	receiver.ID = "receiver"
	s.Clients.Add(receiver)
	s.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c"})
	require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.PacketsReceived))
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
	// Collect everything written to the receiver until its pipe is closed.
	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r2)
		require.NoError(t, err)
		receiverBuf <- buf
	}()
	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
		require.NoError(t, err)
		_ = w.Close()
		_ = w2.Close()
	}()
	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, buf)
	require.Equal(t, []byte{}, <-receiverBuf)
	require.Equal(t, 0, len(s.Topics.Messages("a/b/c")))
}
// TestServerProcessPacketPublishMaximumReceive resets the client's receive
// quota to 0 before a QoS 1 publish is processed. It checks that processing
// fails with packets.ErrReceiveMaximum and that the client is written the
// TDisconnectReceiveMaximum bytes.
func TestServerProcessPacketPublishMaximumReceive(t *testing.T) {
	srv := newServer()
	_ = srv.Serve()
	defer srv.Close()

	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.ResetReceiveQuota(0)
	srv.Clients.Add(client)

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	go func() {
		processErr := srv.processPacket(client, publish)
		require.Error(t, processErr)
		require.ErrorIs(t, processErr, packets.ErrReceiveMaximum)
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectReceiveMaximum).RawBytes, got)
}
// TestServerProcessPublishInvalidTopic processes the TPublishSpecDenySysTopic
// fixture (presumably a publish to a $SYS topic) and checks that no error is
// returned.
func TestServerProcessPublishInvalidTopic(t *testing.T) {
	srv := newServer()
	_ = srv.Serve()
	defer srv.Close()

	client, _, _ := newTestClient()
	sysPublish := *packets.TPacketData[packets.Publish].Get(packets.TPublishSpecDenySysTopic).Packet
	require.NoError(t, srv.processPublish(client, sysPublish)) // $SYS Topics should be ignored?
}
// TestServerProcessPublishACLCheckDeny publishes with a DenyHook installed. For
// each protocol version and QoS it checks:
//   - the error returned by processPublish;
//   - the bytes written back to the client, when expectResponse is set;
//   - whether the client was closed.
func TestServerProcessPublishACLCheckDeny(t *testing.T) {
	tt := []struct {
		name             string
		protocolVersion  byte
		pk               packets.Packet
		expectErr        error
		expectResponse   []byte // nil means the written bytes are not compared
		expectDisconnect bool
	}{
		{
			name:             "v4_QOS0",
			protocolVersion:  4,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet,
			expectErr:        nil,
			expectResponse:   nil,
			expectDisconnect: false,
		},
		{
			name:             "v4_QOS1",
			protocolVersion:  4,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet,
			expectErr:        packets.ErrNotAuthorized,
			expectResponse:   nil,
			expectDisconnect: true,
		},
		{
			name:             "v4_QOS2",
			protocolVersion:  4,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet,
			expectErr:        packets.ErrNotAuthorized,
			expectResponse:   nil,
			expectDisconnect: true,
		},
		{
			name:             "v5_QOS0",
			protocolVersion:  5,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet,
			expectErr:        nil,
			expectResponse:   nil,
			expectDisconnect: false,
		},
		// QoS 1 is acknowledged with a PUBACK, so the expected bytes must come
		// from the Puback table. Looking a Pubrec case id up in that table finds
		// nothing, which left the expectation nil and silently skipped the check.
		{
			name:             "v5_QOS1",
			protocolVersion:  5,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1Mqtt5).Packet,
			expectErr:        nil,
			expectResponse:   packets.TPacketData[packets.Puback].Get(packets.TPubackMqtt5NotAuthorized).RawBytes,
			expectDisconnect: false,
		},
		{
			name:             "v5_QOS2",
			protocolVersion:  5,
			pk:               *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet,
			expectErr:        nil,
			expectResponse:   packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5NotAuthorized).RawBytes,
			expectDisconnect: false,
		},
	}

	for _, tx := range tt {
		t.Run(tx.name, func(t *testing.T) {
			cc := NewDefaultServerCapabilities()
			s := New(&Options{
				Logger:       logger,
				Capabilities: cc,
			})
			_ = s.AddHook(new(DenyHook), nil)
			_ = s.Serve()
			defer s.Close()

			cl, r, w := newTestClient()
			cl.Properties.ProtocolVersion = tx.protocolVersion
			s.Clients.Add(cl)

			wg := sync.WaitGroup{}
			wg.Add(1)
			go func() {
				defer wg.Done()
				err := s.processPublish(cl, tx.pk)
				require.ErrorIs(t, err, tx.expectErr)
				_ = w.Close()
			}()

			buf, err := io.ReadAll(r)
			require.NoError(t, err)
			if tx.expectResponse != nil {
				require.Equal(t, tx.expectResponse, buf)
			}
			require.Equal(t, tx.expectDisconnect, cl.Closed())
			wg.Wait()
		})
	}
}
// TestServerProcessPublishOnMessageRecvRejected installs a hook that fails
// with packets.ErrRejectPacket. It checks that processPublish still returns
// nil, i.e. the rejected publish is dropped silently.
func TestServerProcessPublishOnMessageRecvRejected(t *testing.T) {
	srv := newServer()
	require.NotNil(t, srv)

	rejecter := new(modifiedHookBase)
	rejecter.fail = true
	rejecter.err = packets.ErrRejectPacket
	require.NoError(t, srv.AddHook(rejecter, nil))

	_ = srv.Serve()
	defer srv.Close()

	client, _, _ := newTestClient()
	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	require.NoError(t, srv.processPublish(client, publish)) // packets rejected silently
}
// TestServerProcessPacketPublishQos0 checks that a QoS 0 publish produces no
// reply to the sender.
func TestServerProcessPacketPublishQos0(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet

	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, []byte{}, got)
}
// TestServerProcessPacketPublishQos1PacketIDInUse seeds the client's inflight
// store with a PUBLISH using packet ID 7, then processes an incoming QoS 1
// publish. It checks that a plain TPuback is returned and that the server's
// inflight counter drops from 1 to 0.
func TestServerProcessPacketPublishQos1PacketIDInUse(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Publish}})
	atomic.StoreInt64(&srv.Info.Inflight, 1)

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, got)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Inflight))
}
// TestServerProcessPacketPublishQos2PacketIDInUse seeds an MQTT v5 client's
// inflight store with a PUBREC using packet ID 7, then processes an incoming
// QoS 2 publish. It checks that the TPubrecMqtt5IDInUse reply is written and
// that the server's inflight counter stays at 1.
func TestServerProcessPacketPublishQos2PacketIDInUse(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.Set(packets.Packet{PacketID: 7, FixedHeader: packets.FixedHeader{Type: packets.Pubrec}})
	atomic.StoreInt64(&srv.Info.Inflight, 1)

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2Mqtt5).Packet
	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrecMqtt5IDInUse).RawBytes, got)
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.Inflight))
}
// TestServerProcessPacketPublishQos1 checks that a QoS 1 publish is answered
// with the TPuback bytes.
func TestServerProcessPacketPublishQos1(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet

	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, got)
}
// TestServerProcessPacketPublishQos2 checks that a QoS 2 publish is answered
// with the TPubrec bytes.
func TestServerProcessPacketPublishQos2(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet

	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).RawBytes, got)
}
// TestServerProcessPacketPublishDowngradeQos caps the server's MaximumQos at
// 1. It checks that a QoS 2 publish is then answered with a PUBACK (the
// TPuback bytes) rather than a PUBREC.
func TestServerProcessPacketPublishDowngradeQos(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.MaximumQos = 1
	client, reader, writer := newTestClient()
	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos2).Packet

	go func() {
		require.NoError(t, srv.processPacket(client, publish))
		_ = writer.Close()
	}()

	got, readErr := io.ReadAll(reader)
	require.NoError(t, readErr)
	require.Equal(t, packets.TPacketData[packets.Puback].Get(packets.TPuback).RawBytes, got)
}
// TestPublishToSubscribersSelfNoLocal subscribes a client to "a/b/c" with
// NoLocal set, then publishes a packet whose Origin is that same client. It
// checks that nothing is written back to the client.
func TestPublishToSubscribersSelfNoLocal(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", NoLocal: true})
	require.True(t, subbed)
	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		pkx.Origin = cl.ID // the publisher is the subscriber itself
		s.publishToSubscribers(pkx)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()
	// Collect everything written to the client until its pipe is closed.
	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()
	require.Equal(t, []byte{}, <-receiverBuf)
}
// TestPublishToSubscribers checks fan-out to a plain subscriber and to a
// shared-subscription group:
//   - cl1, subscribed to "a/b/c", always receives the publish;
//   - of cl2 and cl3 (both in the SharePrefix+"/tmp/a/b/c" group), exactly
//     one receives it.
func TestPublishToSubscribers(t *testing.T) {
	srv := newServer()
	plain, r1, w1 := newTestClient()
	plain.ID = "cl1"
	memberA, r2, w2 := newTestClient()
	memberA.ID = "cl2"
	memberB, r3, w3 := newTestClient()
	memberB.ID = "cl3"
	srv.Clients.Add(plain)
	srv.Clients.Add(memberA)
	srv.Clients.Add(memberB)

	shared := SharePrefix + "/tmp/a/b/c"
	require.True(t, srv.Topics.Subscribe(plain.ID, packets.Subscription{Filter: "a/b/c"}))
	require.True(t, srv.Topics.Subscribe(memberA.ID, packets.Subscription{Filter: shared}))
	require.True(t, srv.Topics.Subscribe(memberB.ID, packets.Subscription{Filter: shared}))

	// drain collects everything written to one pipe until it is closed.
	drain := func(conn io.Reader) <-chan []byte {
		out := make(chan []byte)
		go func() {
			buf, err := io.ReadAll(conn)
			require.NoError(t, err)
			out <- buf
		}()
		return out
	}
	recvPlain := drain(r1)
	recvA := drain(r2)
	recvB := drain(r3)

	go func() {
		srv.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		time.Sleep(time.Millisecond)
		_ = w1.Close()
		_ = w2.Close()
		_ = w3.Close()
	}()

	want := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes
	require.Equal(t, want, <-recvPlain)
	gotA := <-recvA
	gotB := <-recvB

	// Pick whichever group member received data; the other must be empty.
	var winner, idle []byte
	switch {
	case len(gotA) > 0:
		winner, idle = gotA, gotB
	case len(gotB) > 0:
		winner, idle = gotB, gotA
	}
	require.True(t, len(winner) > 0)
	require.Equal(t, want, winner)
	require.Equal(t, []byte{}, idle)
}
// TestPublishToSubscribersMessageExpiryDelta publishes a message created 30
// seconds ago to an MQTT v5 subscriber. It checks that the message expiry
// interval written to the subscriber is the server's
// MaximumMessageExpiryInterval minus those 30 seconds.
func TestPublishToSubscribersMessageExpiryDelta(t *testing.T) {
	s := newServer()
	s.Options.Capabilities.MaximumMessageExpiryInterval = 86400
	cl, r1, w1 := newTestClient()
	cl.ID = "cl1"
	cl.Properties.ProtocolVersion = 5
	s.Clients.Add(cl)
	require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c"}))

	cl1Recv := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r1)
		require.NoError(t, err)
		cl1Recv <- buf
	}()

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		pkx.Created = time.Now().Unix() - 30
		s.publishToSubscribers(pkx)
		time.Sleep(time.Millisecond)
		_ = w1.Close()
	}()

	b := <-cl1Recv
	// Assumed layout for this QoS 0 "a/b/c" publish:
	//   b[0:2]   fixed header
	//   b[2:9]   topic with its length prefix
	//   b[9]     properties length
	//   b[10]    Message Expiry Interval property identifier
	//   b[11:15] the 4-byte expiry value
	// Check the length first so a short or empty write fails the test instead
	// of panicking on the slice below.
	require.GreaterOrEqual(t, len(b), 15)
	require.Equal(t, uint32(s.Options.Capabilities.MaximumMessageExpiryInterval-30), binary.BigEndian.Uint32(b[11:15]))
}
// TestPublishToSubscribersIdentifiers subscribes one v5 client with three
// filters carrying subscription identifiers and checks the exact bytes it
// receives for a basic publish.
func TestPublishToSubscribersIdentifiers(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	srv.Clients.Add(client)

	subs := []packets.Subscription{
		{Filter: "a/b/+", Identifier: 2},
		{Filter: "a/#", Identifier: 3},
		{Filter: "d/e/f", Identifier: 4},
	}
	for _, sub := range subs {
		require.True(t, srv.Topics.Subscribe(client.ID, sub))
	}

	go func() {
		srv.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	got := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		got <- data
	}()

	want := packets.TPacketData[packets.Publish].Get(packets.TPublishSubscriberIdentifier).RawBytes
	require.Equal(t, want, <-got)
}
// TestPublishToSubscribersPkIgnore checks that a publish whose Ignore flag is
// set produces no bytes on a matching subscriber's connection.
func TestPublishToSubscribersPkIgnore(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)
	require.True(t, srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "#", Identifier: 1}))

	go func() {
		ignored := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
		ignored.Ignore = true
		srv.publishToSubscribers(ignored)
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	got := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		got <- data
	}()

	require.Equal(t, []byte{}, <-got)
}
// TestPublishToClientServerDowngradeQos sends a QoS 2 publish over a QoS 2
// subscription while the server caps MaximumQos at 1, and expects the QoS 1
// fixture bytes on the wire.
func TestPublishToClientServerDowngradeQos(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.MaximumQos = 1
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)

	_, found := client.State.Inflight.Get(1)
	require.False(t, found)

	// Seed the counter so the packet id used matches the fixture's id (7).
	client.State.packetID = 6

	go func() {
		pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
		pk.FixedHeader.Qos = 2
		_, _ = srv.publishToClient(client, packets.Subscription{Filter: "a/b/c", Qos: 2}, pk)
		time.Sleep(time.Microsecond * 100)
		_ = writer.Close()
	}()

	got := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		got <- data
	}()

	want := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes
	require.Equal(t, want, <-got)
}
// TestPublishToClientSubscriptionDowngradeQos sends a QoS 2 publish over a
// QoS 1 subscription (server allows QoS 2) and expects the QoS 1 fixture bytes.
func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.MaximumQos = 2
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)

	_, found := client.State.Inflight.Get(1)
	require.False(t, found)

	// Seed the counter so the packet id used matches the fixture's id (7).
	client.State.packetID = 6

	go func() {
		pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
		pk.FixedHeader.Qos = 2
		_, _ = srv.publishToClient(client, packets.Subscription{Filter: "a/b/c", Qos: 1}, pk)
		time.Sleep(time.Microsecond * 100)
		_ = writer.Close()
	}()

	got := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		got <- data
	}()

	want := packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes
	require.Equal(t, want, <-got)
}
// TestPublishToClientExceedClientWritesPending fills the client's outbound
// queue up to MaximumClientWritesPending and checks that further publishes
// (QoS 0 and QoS 1 input) fail with ErrPendingClientWritesExceeded while the
// client's send quota is left unchanged.
func TestPublishToClientExceedClientWritesPending(t *testing.T) {
	var sendQuota uint16 = 5
	s := newServer()
	_, w := net.Pipe()
	cl := newClient(w, &ops{
		info:  new(system.Info),
		hooks: new(Hooks),
		log:   logger,
		options: &Options{
			Capabilities: &Capabilities{
				MaximumClientWritesPending: 3,
				maximumPacketID:            10,
			},
		},
	})
	cl.Properties.Props.ReceiveMaximum = sendQuota
	cl.State.Inflight.ResetSendQuota(int32(cl.Properties.Props.ReceiveMaximum))
	s.Clients.Add(cl)

	// Occupy every pending-write slot.
	for i := int32(0); i < cl.ops.options.Capabilities.MaximumClientWritesPending; i++ {
		cl.State.outbound <- new(packets.Packet)
		atomic.AddInt32(&cl.State.outboundQty, 1)
	}

	// Consume one unit of send quota with an inflight packet.
	id, _ := cl.NextPacketID()
	cl.State.Inflight.Set(packets.Packet{PacketID: uint16(id)})
	cl.State.Inflight.DecreaseSendQuota()
	sendQuota--

	// require.ErrorIs takes (err, target): the returned error comes first.
	_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{})
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrPendingClientWritesExceeded)
	require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota))

	_, err = s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 2}, packets.Packet{FixedHeader: packets.FixedHeader{Qos: 1}})
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrPendingClientWritesExceeded)
	require.Equal(t, int32(sendQuota), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
}
// TestPublishToClientServerTopicAlias publishes the same v5 message twice to a
// client that accepts topic aliases. It checks the combined output length: one
// fixture-sized packet followed by one packet 5 bytes shorter (presumably the
// second publish carries the alias instead of the topic name).
func TestPublishToClientServerTopicAlias(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 5
	cl.Properties.Props.TopicAliasMaximum = 5
	s.Clients.Add(cl)

	go func() {
		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).Packet
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
		_, _ = s.publishToClient(cl, packets.Subscription{Filter: pkx.TopicName}, pkx)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	receiverBuf := make(chan []byte)
	go func() {
		buf, err := io.ReadAll(r)
		require.NoError(t, err)
		receiverBuf <- buf
	}()

	ret := <-receiverBuf
	// Assert the total length explicitly. Copying ret into two slices and
	// comparing the concatenation with ret only ever verified this length, and
	// panicked instead of failing when fewer than rawLen bytes arrived.
	rawLen := len(packets.TPacketData[packets.Publish].Get(packets.TPublishBasicMqtt5).RawBytes)
	require.Len(t, ret, rawLen+(rawLen-5))
}
// TestPublishToClientMqtt3RetainFalseLeverageNoConn uses a client left at
// newTestClient's default protocol version (MQTT 3.x per the test name). It
// checks that a retained publish comes back with Retain cleared despite
// RetainAsPublished, and that the nil connection yields CodeDisconnect.
func TestPublishToClientMqtt3RetainFalseLeverageNoConn(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Net.Conn = nil

	sub := packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}
	retained := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
	out, err := srv.publishToClient(client, sub, retained)

	require.False(t, out.FixedHeader.Retain)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.CodeDisconnect)
}
// TestPublishToClientMqtt5RetainAsPublishedTrueLeverageNoConn checks that for a
// v5 client with RetainAsPublished, the outgoing publish keeps its Retain
// flag, and that the nil connection yields CodeDisconnect.
func TestPublishToClientMqtt5RetainAsPublishedTrueLeverageNoConn(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.Net.Conn = nil

	sub := packets.Subscription{Filter: "a/b/c", RetainAsPublished: true}
	retained := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
	out, err := srv.publishToClient(client, sub, retained)

	require.True(t, out.FixedHeader.Retain)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.CodeDisconnect)
}
// TestPublishToClientExceedMaximumInflight fills the inflight store to
// MaximumInflight and checks that a QoS 1 publish is refused with
// ErrQuotaExceeded and counted in InflightDropped.
func TestPublishToClientExceedMaximumInflight(t *testing.T) {
	const maxInflight uint16 = 5
	srv := newServer()
	client, _, _ := newTestClient()
	srv.Options.Capabilities.MaximumInflight = maxInflight
	client.ops.options.Capabilities.MaximumInflight = maxInflight

	for id := uint16(0); id < maxInflight; id++ {
		client.State.Inflight.Set(packets.Packet{PacketID: id})
	}

	qos1 := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	_, err := srv.publishToClient(client, packets.Subscription{Filter: "a/b/c", Qos: 1}, qos1)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrQuotaExceeded)
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.InflightDropped))
}
// TestPublishToClientExhaustedPacketID occupies every packet id from 0 to
// maximumPacketID and checks that a QoS 1 publish fails with ErrQuotaExceeded
// and increments InflightDropped.
func TestPublishToClientExhaustedPacketID(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()

	limit := client.ops.options.Capabilities.maximumPacketID
	for id := uint32(0); id <= limit; id++ {
		client.State.Inflight.Set(packets.Packet{PacketID: uint16(id)})
	}

	qos1 := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	_, err := srv.publishToClient(client, packets.Subscription{Filter: "a/b/c", Qos: 1}, qos1)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrQuotaExceeded)
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.InflightDropped))
}
// TestPublishToClientACLNotAuthorized installs a DenyHook and checks that
// publishing to a client is rejected with ErrNotAuthorized.
func TestPublishToClientACLNotAuthorized(t *testing.T) {
	srv := New(&Options{
		Logger: logger,
	})
	require.NoError(t, srv.AddHook(new(DenyHook), nil))

	client, _, _ := newTestClient()
	basic := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	_, err := srv.publishToClient(client, packets.Subscription{Filter: "a/b/c"}, basic)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrNotAuthorized)
}
// TestPublishToClientNoConn checks that publishing to a client whose
// connection is nil fails with CodeDisconnect.
func TestPublishToClientNoConn(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Net.Conn = nil

	qos1 := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	_, err := srv.publishToClient(client, packets.Subscription{Filter: "a/b/c"}, qos1)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.CodeDisconnect)
}
// TestProcessPublishWithTopicAlias has a v5 publisher send a publish with an
// empty topic and inbound alias 1 (mapped to "a/b/c"). It checks that a
// subscriber on "a/b/c" receives the basic publish bytes.
func TestProcessPublishWithTopicAlias(t *testing.T) {
	srv := newServer()
	subscriber, subReader, subWriter := newTestClient()
	srv.Clients.Add(subscriber)
	require.True(t, srv.Topics.Subscribe(subscriber.ID, packets.Subscription{Filter: "a/b/c", Qos: 0}))

	publisher, _, pubWriter := newTestClient()
	publisher.Properties.ProtocolVersion = 5
	publisher.State.TopicAliases.Inbound.Set(1, "a/b/c")

	go func() {
		aliased := *packets.TPacketData[packets.Publish].Get(packets.TPublishMqtt5).Packet
		aliased.Properties.SubscriptionIdentifier = []int{} // client-to-server publishes must not carry these
		aliased.TopicName = ""
		aliased.Properties.TopicAlias = 1
		_ = srv.processPacket(publisher, aliased)
		time.Sleep(time.Millisecond)
		_ = pubWriter.Close()
		_ = subWriter.Close()
	}()

	data, err := io.ReadAll(subReader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, data)
}
// TestPublishToSubscribersExhaustedSendQuota is a coverage-only test: it
// publishes to a QoS 2 subscriber whose send quota is zero and whose reader is
// closed. publishToSubscribers reports no errors, so nothing is asserted
// (TODO: capture the log output to assert on it).
func TestPublishToSubscribersExhaustedSendQuota(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)
	client.State.Inflight.sendQuota = 0
	require.True(t, srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}))

	_ = reader.Close()

	pk := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	pk.PacketID = 0
	srv.publishToSubscribers(pk)
	time.Sleep(time.Millisecond)
	_ = writer.Close()
}
// TestPublishToSubscribersExhaustedPacketIDs is a coverage-only test. It
// occupies every packet id in the subscriber's inflight store, then publishes
// to its QoS 2 subscription. publishToSubscribers does not return subscriber
// errors, so nothing is asserted (TODO: capture log output to assert on it).
func TestPublishToSubscribersExhaustedPacketIDs(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	// Use a distinct id per iteration; setting the same id repeatedly would
	// leave almost the whole id space free and never exercise exhaustion.
	for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
	}
	subbed := s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2})
	require.True(t, subbed)
	// coverage: subscriber publish errors are non-returnable
	// can we hook into zerolog ?
	_ = r.Close()
	pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
	pkx.PacketID = 0
	s.publishToSubscribers(pkx)
	time.Sleep(time.Millisecond)
	_ = w.Close()
}
// TestPublishToSubscribersNoConnection is a coverage-only test: it publishes
// to a subscriber whose reader end is already closed. publishToSubscribers
// surfaces no error, so nothing is asserted (TODO: capture log output).
func TestPublishToSubscribersNoConnection(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)
	require.True(t, srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}))

	_ = reader.Close()

	basic := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	srv.publishToSubscribers(basic)
	time.Sleep(time.Millisecond)
	_ = writer.Close()
}
// TestPublishRetainedToClient stores a retained v5 message and checks that
// publishRetainedToClient writes the TPublishRetain fixture bytes to the
// (default protocol version) client.
func TestPublishRetainedToClient(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)
	require.True(t, srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}))

	stored := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetainMqtt5).Packet)
	require.Equal(t, int64(1), stored)

	go func() {
		srv.publishRetainedToClient(client, packets.Subscription{Filter: "a/b/c"}, false)
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes, data)
}
// TestPublishRetainedToClientIsShared checks that calling
// publishRetainedToClient with a shared-subscription filter writes nothing.
func TestPublishRetainedToClientIsShared(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)

	shared := packets.Subscription{Filter: SharePrefix + "/test/a/b/c"}
	require.True(t, srv.Topics.Subscribe(client.ID, shared))

	go func() {
		srv.publishRetainedToClient(client, shared, false)
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, []byte{}, data)
}
// TestPublishRetainedToClientError calls publishRetainedToClient after the
// client's writer has been closed. This exercises the write-failure path; it
// has no assertions beyond the call completing.
func TestPublishRetainedToClientError(t *testing.T) {
	srv := newServer()
	client, _, writer := newTestClient()
	srv.Clients.Add(client)

	filter := packets.Subscription{Filter: "a/b/c"}
	require.True(t, srv.Topics.Subscribe(client.ID, filter))

	stored := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), stored)

	_ = writer.Close()
	srv.publishRetainedToClient(client, filter, false)
}
// TestNoRetainMessageIfUnavailable checks that retainMessage stores nothing
// when the server's RetainAvailable capability is 0.
func TestNoRetainMessageIfUnavailable(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.RetainAvailable = 0
	client, _, _ := newTestClient()
	srv.Clients.Add(client)

	retained := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
	srv.retainMessage(new(Client), retained)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Retained))
}
// TestNoRetainMessageIfPkIgnore checks that retainMessage skips a packet whose
// Ignore flag is set.
func TestNoRetainMessageIfPkIgnore(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	srv.Clients.Add(client)

	ignored := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
	ignored.Ignore = true
	srv.retainMessage(new(Client), ignored)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Retained))
}
// TestNoRetainMessage is, despite its name, the positive case: a retained
// publish passed to retainMessage increments Info.Retained to 1.
func TestNoRetainMessage(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	srv.Clients.Add(client)

	retained := *packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet
	srv.retainMessage(new(Client), retained)
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.Retained))
}
// TestServerProcessPacketPuback checks, for both MQTT v4 and v5, that a PUBACK
// for an inflight packet id removes it from the inflight store, raises the
// send quota from 3 to 4, leaves the receive quota alone, and brings the
// server inflight counter back to zero.
func TestServerProcessPacketPuback(t *testing.T) {
	tt := ProtocolTest{
		{
			protocolVersion: 4,
			in:              packets.TPacketData[packets.Puback].Get(packets.TPuback),
		},
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Puback].Get(packets.TPubackMqtt5),
		},
	}
	for _, tx := range tt {
		t.Run(strconv.Itoa(int(tx.protocolVersion)), func(t *testing.T) {
			pID := uint16(7)
			s := newServer()
			cl, _, _ := newTestClient()
			// Process each fixture with a client speaking that case's protocol
			// version, as the PUBCOMP test does, so the v5 case really is v5.
			cl.Properties.ProtocolVersion = tx.protocolVersion
			cl.State.Inflight.sendQuota = 3
			cl.State.Inflight.receiveQuota = 3
			cl.State.Inflight.Set(packets.Packet{PacketID: pID})
			atomic.AddInt64(&s.Info.Inflight, 1)
			err := s.processPacket(cl, *tx.in.Packet)
			require.NoError(t, err)
			require.Equal(t, int32(4), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
			require.Equal(t, int32(3), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
			require.Equal(t, int64(0), atomic.LoadInt64(&s.Info.Inflight))
			_, ok := cl.State.Inflight.Get(pID)
			require.False(t, ok)
		})
	}
}
// TestServerProcessPacketPubackNoPacketID checks that a PUBACK for an id that
// is not inflight succeeds and leaves both quotas untouched.
func TestServerProcessPacketPubackNoPacketID(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3

	ack := *packets.TPacketData[packets.Puback].Get(packets.TPuback).Packet
	require.NoError(t, srv.processPacket(client, ack))
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.sendQuota))
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
}
// TestServerProcessPacketPubrec checks that a PUBREC for an inflight id is
// answered with PUBREL. It also checks that the receive quota drops to 2, the
// send quota stays at 3, and the packet stays inflight.
func TestServerProcessPacketPubrec(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, reader, writer := newTestClient()
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})
	atomic.AddInt64(&srv.Info.Inflight, 1)

	acks := make(chan []byte)
	go func() { // collect whatever the server writes back
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		acks <- data
	}()

	require.NoError(t, srv.processPacket(client, *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet))
	_ = writer.Close()

	require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).RawBytes, <-acks)
	require.Equal(t, int32(2), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.sendQuota))
	require.Equal(t, int64(1), atomic.LoadInt64(&srv.Info.Inflight))

	_, stillInflight := client.State.Inflight.Get(packetID)
	require.True(t, stillInflight)
}
// TestServerProcessPacketPubrecNoPacketID checks that a v5 client's PUBREC for
// an unknown id is answered with the "packet id not found" PUBREL fixture,
// with both quotas unchanged.
func TestServerProcessPacketPubrecNoPacketID(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3

	acks := make(chan []byte)
	go func() { // collect whatever the server writes back
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		acks <- data
	}()

	rec := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet // sent without properties
	require.NoError(t, srv.processPacket(client, rec))
	_ = writer.Close()

	require.Equal(t, packets.TPacketData[packets.Pubrel].Get(packets.TPubrelMqtt5AckNoPacket).RawBytes, <-acks)
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.sendQuota))
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
}
// TestServerProcessPacketPubrecInvalidReason checks that a PUBREC carrying an
// invalid reason code drops the inflight packet and decrements Info.Inflight
// (from 0 to -1 here, since it was never incremented).
func TestServerProcessPacketPubrecInvalidReason(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})

	rec := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrecInvalidReason).Packet
	require.NoError(t, srv.processPacket(client, rec))
	require.Equal(t, int64(-1), atomic.LoadInt64(&srv.Info.Inflight))

	_, stillInflight := client.State.Inflight.Get(packetID)
	require.False(t, stillInflight)
}
// TestServerProcessPacketPubrecFailure checks that processing a PUBREC for a
// stopped client returns an error and that the stop cause is CodeDisconnect.
func TestServerProcessPacketPubrecFailure(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})
	client.Stop(packets.CodeDisconnect)

	rec := *packets.TPacketData[packets.Pubrec].Get(packets.TPubrec).Packet
	require.Error(t, srv.processPacket(client, rec))
	require.ErrorIs(t, client.StopCause(), packets.CodeDisconnect)
}
// TestServerProcessPacketPubrel checks that a PUBREL for an inflight id is
// answered with PUBCOMP. Both quotas rise from 3 to 4, the server inflight
// counter returns to 0, and the packet leaves the inflight store.
func TestServerProcessPacketPubrel(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, reader, writer := newTestClient()
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})
	atomic.AddInt64(&srv.Info.Inflight, 1)

	acks := make(chan []byte)
	go func() { // collect whatever the server writes back
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		acks <- data
	}()

	require.NoError(t, srv.processPacket(client, *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet))
	_ = writer.Close()

	require.Equal(t, int32(4), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
	require.Equal(t, int32(4), atomic.LoadInt32(&client.State.Inflight.sendQuota))
	require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp).RawBytes, <-acks)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Inflight))

	_, stillInflight := client.State.Inflight.Get(packetID)
	require.False(t, stillInflight)
}
// TestServerProcessPacketPubrelNoPacketID checks that a v5 client's PUBREL for
// an unknown id is answered with the "packet id not found" PUBCOMP fixture,
// with both quotas unchanged.
func TestServerProcessPacketPubrelNoPacketID(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3

	acks := make(chan []byte)
	go func() { // collect whatever the server writes back
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		acks <- data
	}()

	rel := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet // sent without properties
	require.NoError(t, srv.processPacket(client, rel))
	_ = writer.Close()

	require.Equal(t, packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5AckNoPacket).RawBytes, <-acks)
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.sendQuota))
	require.Equal(t, int32(3), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
}
// TestServerProcessPacketPubrelFailure checks that processing a PUBREL for a
// stopped client returns an error and that the stop cause is CodeDisconnect.
func TestServerProcessPacketPubrelFailure(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})
	client.Stop(packets.CodeDisconnect)

	rel := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrel).Packet
	require.Error(t, srv.processPacket(client, rel))
	require.ErrorIs(t, client.StopCause(), packets.CodeDisconnect)
}
// TestServerProcessPacketPubrelBadReason checks that a PUBREL with an invalid
// reason code drops the inflight packet and decrements Info.Inflight (to -1
// here, since it was never incremented).
func TestServerProcessPacketPubrelBadReason(t *testing.T) {
	const packetID uint16 = 7
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.Set(packets.Packet{PacketID: packetID})

	rel := *packets.TPacketData[packets.Pubrel].Get(packets.TPubrelInvalidReason).Packet
	require.NoError(t, srv.processPacket(client, rel))
	require.Equal(t, int64(-1), atomic.LoadInt64(&srv.Info.Inflight))

	_, stillInflight := client.State.Inflight.Get(packetID)
	require.False(t, stillInflight)
}
// TestServerProcessPacketPubcomp checks, for MQTT v4 and v5, that a PUBCOMP
// for an inflight id completes the flow. The server inflight counter returns
// to 0, both quotas rise from 3 to 4, and the packet leaves the inflight store.
func TestServerProcessPacketPubcomp(t *testing.T) {
	cases := ProtocolTest{
		{
			protocolVersion: 4,
			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
		},
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcompMqtt5),
		},
	}

	for _, tc := range cases {
		name := strconv.Itoa(int(tc.protocolVersion))
		t.Run(name, func(t *testing.T) {
			const packetID uint16 = 7
			srv := newServer()
			client, _, _ := newTestClient()
			client.Properties.ProtocolVersion = tc.protocolVersion
			client.State.Inflight.sendQuota = 3
			client.State.Inflight.receiveQuota = 3
			client.State.Inflight.Set(packets.Packet{PacketID: packetID})
			atomic.AddInt64(&srv.Info.Inflight, 1)

			require.NoError(t, srv.processPacket(client, *tc.in.Packet))
			require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Inflight))
			require.Equal(t, int32(4), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
			require.Equal(t, int32(4), atomic.LoadInt32(&client.State.Inflight.sendQuota))

			_, stillInflight := client.State.Inflight.Get(packetID)
			require.False(t, stillInflight)
		})
	}
}
// TestServerProcessInboundQos2Flow drives an inbound QoS 2 exchange on one
// client (PUBLISH -> PUBREC, then PUBREL -> PUBCOMP). Each step runs on a
// fresh pipe and checks the bytes written back plus the inflight and quota
// counters. Afterwards the packet must be gone from the inflight store.
func TestServerProcessInboundQos2Flow(t *testing.T) {
	steps := ProtocolTest{
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
			out:             packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
			data: map[string]any{
				"sendquota": int32(3),
				"recvquota": int32(2),
				"inflight":  int64(1),
			},
		},
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
			out:             packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
			data: map[string]any{
				"sendquota": int32(4),
				"recvquota": int32(3),
				"inflight":  int64(0),
			},
		},
	}

	const packetID uint16 = 7
	srv := newServer()
	client, _, _ := newTestClient()
	client.State.Inflight.sendQuota = 3
	client.State.Inflight.receiveQuota = 3

	for idx, step := range steps {
		t.Run("qos step"+strconv.Itoa(idx), func(t *testing.T) {
			reader, writer := net.Pipe()
			client.Net.Conn = writer

			acks := make(chan []byte)
			go func() { // collect whatever the server writes back
				data, err := io.ReadAll(reader)
				require.NoError(t, err)
				acks <- data
			}()

			require.NoError(t, srv.processPacket(client, *step.in.Packet))
			_ = writer.Close()
			require.Equal(t, step.out.RawBytes, <-acks)

			if idx == 0 {
				_, inflight := client.State.Inflight.Get(packetID)
				require.True(t, inflight)
			}

			require.Equal(t, step.data["inflight"].(int64), atomic.LoadInt64(&srv.Info.Inflight))
			require.Equal(t, step.data["recvquota"].(int32), atomic.LoadInt32(&client.State.Inflight.receiveQuota))
			require.Equal(t, step.data["sendquota"].(int32), atomic.LoadInt32(&client.State.Inflight.sendQuota))
		})
	}

	_, inflight := client.State.Inflight.Get(packetID)
	require.False(t, inflight)
}
// TestServerProcessOutboundQos2Flow drives an outbound QoS 2 exchange for a
// subscribed client: server PUBLISH, client PUBREC -> server PUBREL, then
// client PUBCOMP. Each step checks the bytes written (except the last, which
// writes nothing checked) and the inflight/quota counters.
func TestServerProcessOutboundQos2Flow(t *testing.T) {
	tt := ProtocolTest{
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
			out:             packets.TPacketData[packets.Publish].Get(packets.TPublishQos2),
			data: map[string]any{
				"sendquota": int32(2),
				"recvquota": int32(3),
				"inflight":  int64(1),
			},
		},
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Pubrec].Get(packets.TPubrec),
			out:             packets.TPacketData[packets.Pubrel].Get(packets.TPubrel),
			data: map[string]any{
				"sendquota": int32(2),
				"recvquota": int32(2),
				"inflight":  int64(1),
			},
		},
		{
			protocolVersion: 5,
			in:              packets.TPacketData[packets.Pubcomp].Get(packets.TPubcomp),
			data: map[string]any{
				"sendquota": int32(3),
				"recvquota": int32(3),
				"inflight":  int64(0),
			},
		},
	}

	// The counter is seeded with 6 so the outbound publish is assigned id 7,
	// matching the fixtures. The final inflight check must therefore look up
	// 7; looking up 6 would pass trivially because 6 is never used.
	pID := uint16(7)
	s := newServer()
	cl, _, _ := newTestClient()
	cl.State.packetID = uint32(6)
	cl.State.Inflight.sendQuota = 3
	cl.State.Inflight.receiveQuota = 3
	s.Clients.Add(cl)
	require.True(t, s.Topics.Subscribe(cl.ID, packets.Subscription{Filter: "a/b/c", Qos: 2}))

	for i, tx := range tt {
		t.Run("qos step"+strconv.Itoa(i), func(t *testing.T) {
			r, w := net.Pipe()
			time.Sleep(time.Millisecond)
			cl.Net.Conn = w
			recv := make(chan []byte)
			go func() { // receive the ack
				buf, err := io.ReadAll(r)
				require.NoError(t, err)
				recv <- buf
			}()
			if i == 0 {
				s.publishToSubscribers(*tx.in.Packet)
			} else {
				err := s.processPacket(cl, *tx.in.Packet)
				require.NoError(t, err)
			}
			time.Sleep(time.Millisecond)
			_ = w.Close()
			if i != 2 {
				require.Equal(t, tx.out.RawBytes, <-recv)
			}
			require.Equal(t, tx.data["inflight"].(int64), atomic.LoadInt64(&s.Info.Inflight))
			require.Equal(t, tx.data["recvquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.receiveQuota))
			require.Equal(t, tx.data["sendquota"].(int32), atomic.LoadInt32(&cl.State.Inflight.sendQuota))
		})
	}

	_, ok := cl.State.Inflight.Get(pID)
	require.False(t, ok)
}
// TestServerProcessPacketSubscribe checks that a v5 SUBSCRIBE is answered with
// the expected v5 SUBACK bytes.
func TestServerProcessPacketSubscribe(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet
		require.NoError(t, srv.processPacket(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackMqtt5).RawBytes, data)
}
// TestServerProcessPacketSubscribePacketIDInUse checks that a SUBSCRIBE reusing
// an id already held by an inflight PUBLISH gets the "packet id in use" SUBACK.
func TestServerProcessPacketSubscribePacketIDInUse(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})

	sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMqtt5).Packet
	sub.PacketID = 15

	go func() {
		require.NoError(t, srv.processPacket(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackPacketIDInUse).RawBytes, data)
}
// TestServerProcessPacketSubscribeInvalid checks that a SUBSCRIBE without a
// packet id is rejected with ErrProtocolViolationNoPacketID.
func TestServerProcessPacketSubscribeInvalid(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Properties.ProtocolVersion = 5

	bad := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeSpecQosMustPacketID).Packet
	err := srv.processPacket(client, bad)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
}
// TestServerProcessPacketSubscribeInvalidFilter checks that a SUBSCRIBE with an
// invalid filter is answered with the invalid-filter SUBACK.
func TestServerProcessPacketSubscribeInvalidFilter(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidFilter).Packet
		require.NoError(t, srv.processPacket(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidFilter).RawBytes, data)
}
// TestServerProcessPacketSubscribeInvalidSharedNoLocal checks that a shared
// subscription requesting NoLocal is answered with the matching error SUBACK.
func TestServerProcessPacketSubscribeInvalidSharedNoLocal(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet
		require.NoError(t, srv.processPacket(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackInvalidSharedNoLocal).RawBytes, data)
}
// TestServerProcessSubscribeWithRetain checks that subscribing to a topic with
// a retained message writes the SUBACK followed by the retained PUBLISH.
func TestServerProcessSubscribeWithRetain(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retained)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet)
		require.NoError(t, err)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)

	// Build the expectation in a fresh slice. Appending directly onto a shared
	// fixture's RawBytes could write into the fixture's backing array whenever
	// it has spare capacity.
	suback := packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes
	publish := packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes
	want := make([]byte, 0, len(suback)+len(publish))
	want = append(want, suback...)
	want = append(want, publish...)
	require.Equal(t, want, buf)
}
// TestServerProcessSubscribeDowngradeQos subscribes to several filters while
// the server's MaximumQos is 1. It checks that the SUBACK payload after the
// first four bytes is {0, 1, 1}.
func TestServerProcessSubscribeDowngradeQos(t *testing.T) {
	srv := newServer()
	srv.Options.Capabilities.MaximumQos = 1
	client, reader, writer := newTestClient()

	go func() {
		many := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeMany).Packet
		require.NoError(t, srv.processPacket(client, many))
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	granted := data[4:]
	require.Equal(t, []byte{0, 1, 1}, granted)
}
// TestServerProcessSubscribeWithRetainHandling1 pre-subscribes the client to
// "a/b/c", then sends a SUBSCRIBE with retain handling 1. It checks that only
// the SUBACK is written and no retained message follows.
func TestServerProcessSubscribeWithRetainHandling1(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "a/b/c"})
	srv.Clients.Add(client)

	stored := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), stored)

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling1).Packet
		require.NoError(t, srv.processPacket(client, sub))
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, data)
}
// TestServerProcessSubscribeWithRetainHandling2 checks that a SUBSCRIBE with
// retain handling 2 is answered with only the SUBACK, even though a retained
// message exists for the topic.
func TestServerProcessSubscribeWithRetainHandling2(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	srv.Clients.Add(client)

	stored := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), stored)

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainHandling2).Packet
		require.NoError(t, srv.processPacket(client, sub))
		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes, data)
}
// TestServerProcessSubscribeWithNotRetainAsPublished sends a SUBSCRIBE from the
// TSubscribeRetainAsPublished fixture while a retained message exists. It
// checks that the SUBACK is followed by the retained PUBLISH fixture bytes.
func TestServerProcessSubscribeWithNotRetainAsPublished(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	s.Clients.Add(cl)
	retained := s.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retained)

	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeRetainAsPublished).Packet)
		require.NoError(t, err)
		time.Sleep(time.Millisecond)
		_ = w.Close()
	}()

	buf, err := io.ReadAll(r)
	require.NoError(t, err)

	// Build the expectation in a fresh slice. Appending directly onto a shared
	// fixture's RawBytes could write into the fixture's backing array whenever
	// it has spare capacity.
	suback := packets.TPacketData[packets.Suback].Get(packets.TSuback).RawBytes
	publish := packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).RawBytes
	want := make([]byte, 0, len(suback)+len(publish))
	want = append(want, suback...)
	want = append(want, publish...)
	require.Equal(t, want, buf)
}
// TestServerProcessSubscribeNoConnection closes the read end of the client's
// pipe. It checks that processSubscribe fails with io.ErrClosedPipe when
// writing the SUBACK.
func TestServerProcessSubscribeNoConnection(t *testing.T) {
	srv := newServer()
	client, reader, _ := newTestClient()
	_ = reader.Close()

	sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
	err := srv.processSubscribe(client, sub)
	require.Error(t, err)
	require.ErrorIs(t, err, io.ErrClosedPipe)
}
// TestServerProcessSubscribeACLCheckDeny uses a server built with New and
// started with Serve (no extra hooks added here). It checks that a v5 client's
// SUBSCRIBE is answered with the deny SUBACK fixture.
func TestServerProcessSubscribeACLCheckDeny(t *testing.T) {
	srv := New(&Options{
		Logger: logger,
	})
	_ = srv.Serve()

	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
		require.NoError(t, srv.processSubscribe(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackDeny).RawBytes, data)
}
// TestServerProcessSubscribeACLCheckDenyObscure is the same setup as the deny
// test, but with ObscureNotAuthorized enabled. It expects the
// unspecified-error v5 SUBACK instead of the explicit deny.
func TestServerProcessSubscribeACLCheckDenyObscure(t *testing.T) {
	srv := New(&Options{
		Logger: logger,
	})
	_ = srv.Serve()
	srv.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true

	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5

	go func() {
		sub := *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribe).Packet
		require.NoError(t, srv.processSubscribe(client, sub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedErrorMqtt5).RawBytes, data)
}
// TestServerProcessSubscribeErrorDowngrade sends a SUBSCRIBE that is invalid
// for v5 (shared + NoLocal) from a client with ProtocolVersion 3. It expects
// the generic unspecified-error SUBACK fixture.
func TestServerProcessSubscribeErrorDowngrade(t *testing.T) {
	s := newServer()
	cl, r, w := newTestClient()
	cl.Properties.ProtocolVersion = 3
	// NOTE(review): seeding packetID with 1 does not make the next id 7, and a
	// SUBACK carries the SUBSCRIBE's own packet id. This assignment likely has
	// no effect on the expected bytes; confirm before removing it.
	cl.State.packetID = 1
	go func() {
		err := s.processPacket(cl, *packets.TPacketData[packets.Subscribe].Get(packets.TSubscribeInvalidSharedNoLocal).Packet)
		require.NoError(t, err)
		_ = w.Close()
	}()
	buf, err := io.ReadAll(r)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Suback].Get(packets.TSubackUnspecifiedError).RawBytes, buf)
}
// TestServerProcessPacketUnsubscribe checks that a v5 UNSUBSCRIBE is answered
// with the v5 UNSUBACK fixture and decrements Info.Subscriptions (to -1, since
// the direct Topics.Subscribe call here does not increment it).
func TestServerProcessPacketUnsubscribe(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	srv.Topics.Subscribe(client.ID, packets.Subscription{Filter: "a/b", Qos: 0})

	go func() {
		unsub := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet
		require.NoError(t, srv.processPacket(client, unsub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackMqtt5).RawBytes, data)
	require.Equal(t, int64(-1), atomic.LoadInt64(&srv.Info.Subscriptions))
}
// TestServerProcessPacketUnsubscribePackedIDInUse checks that an UNSUBSCRIBE
// whose packet id is held by an inflight PUBLISH gets the "packet id in use"
// UNSUBACK and leaves Info.Subscriptions at 0.
func TestServerProcessPacketUnsubscribePackedIDInUse(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.State.Inflight.Set(packets.Packet{PacketID: 15, FixedHeader: packets.FixedHeader{Type: packets.Publish}})

	go func() {
		unsub := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeMqtt5).Packet
		require.NoError(t, srv.processPacket(client, unsub))
		_ = writer.Close()
	}()

	data, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Unsuback].Get(packets.TUnsubackPacketIDInUse).RawBytes, data)
	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.Subscriptions))
}
// TestServerProcessPacketUnsubscribeInvalid checks that processPacket rejects the
// TUnsubscribeSpecQosMustPacketID fixture with ErrProtocolViolationNoPacketID.
func TestServerProcessPacketUnsubscribeInvalid(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	pk := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet

	err := srv.processPacket(client, pk)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
}
// TestServerReceivePacketError checks that receivePacket propagates
// ErrProtocolViolationNoPacketID for the TUnsubscribeSpecQosMustPacketID fixture.
func TestServerReceivePacketError(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	pk := *packets.TPacketData[packets.Unsubscribe].Get(packets.TUnsubscribeSpecQosMustPacketID).Packet

	err := srv.receivePacket(client, pk)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationNoPacketID)
}
// TestServerRecievePacketDisconnectClientZeroNonZero sets up an MQTT5 client with
// a zero session expiry and problem info disabled. It then feeds it the
// TDisconnectMqtt5 fixture and expects two things: receivePacket returns
// ErrProtocolViolationZeroNonZeroExpiry, and the server writes the
// TDisconnectZeroNonZeroExpiry bytes back to the client.
func TestServerRecievePacketDisconnectClientZeroNonZero(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.Properties.Props.SessionExpiryInterval = 0
	client.Properties.Props.RequestProblemInfo = 0
	client.Properties.Props.RequestProblemInfoFlag = true

	disconnect := *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet
	go func() {
		recvErr := srv.receivePacket(client, disconnect)
		require.Error(t, recvErr)
		require.ErrorIs(t, recvErr, packets.ErrProtocolViolationZeroNonZeroExpiry)
		_ = writer.Close()
	}()

	got, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectZeroNonZeroExpiry).RawBytes, got)
}
// TestServerRecievePacketDisconnectClient checks that DisconnectClient with
// CodeDisconnect writes exactly the TDisconnect fixture bytes to the client.
func TestServerRecievePacketDisconnectClient(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	go func() {
		require.NoError(t, srv.DisconnectClient(client, packets.CodeDisconnect))
		_ = writer.Close()
	}()

	got, err := io.ReadAll(reader)
	require.NoError(t, err)
	want := packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes
	require.Equal(t, want, got)
}
// TestServerProcessPacketDisconnect sends the TDisconnectMqtt5 fixture from an
// MQTT5 client with a 30s session expiry. It checks three things: the client's
// pending delayed will is removed, the client is closed, and its disconnection
// timestamp is recorded.
func TestServerProcessPacketDisconnect(t *testing.T) {
	s := newServer()
	cl, _, _ := newTestClient()
	cl.Properties.Props.SessionExpiryInterval = 30
	cl.Properties.ProtocolVersion = 5
	s.loop.willDelayed.Add(cl.ID, packets.Packet{TopicName: "a/b/c", Payload: []byte("hello")})
	require.Equal(t, 1, s.loop.willDelayed.Len())

	// Bracket the call with timestamps so the recorded disconnect time is checked
	// against a window. Comparing it for equality with a later time.Now() is
	// flaky whenever the second ticks over mid-test.
	before := time.Now().Unix()
	err := s.processPacket(cl, *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet)
	after := time.Now().Unix()

	require.NoError(t, err)
	require.Equal(t, 0, s.loop.willDelayed.Len())
	require.True(t, cl.Closed())

	disconnected := atomic.LoadInt64(&cl.State.disconnected)
	require.GreaterOrEqual(t, disconnected, before)
	require.LessOrEqual(t, disconnected, after)
}
// TestServerProcessPacketDisconnectNonZeroExpiryViolation sets up an MQTT5 client
// with a zero session expiry. Processing the TDisconnectMqtt5 fixture for it
// (presumably carrying a non-zero expiry — confirm against the fixture) must
// fail with ErrProtocolViolationZeroNonZeroExpiry.
func TestServerProcessPacketDisconnectNonZeroExpiryViolation(t *testing.T) {
	srv := newServer()
	client, _, _ := newTestClient()
	client.Properties.ProtocolVersion = 5
	client.Properties.Props.SessionExpiryInterval = 0
	client.Properties.Props.RequestProblemInfoFlag = true
	client.Properties.Props.RequestProblemInfo = 0

	disconnect := *packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectMqtt5).Packet
	err := srv.processPacket(client, disconnect)
	require.Error(t, err)
	require.ErrorIs(t, err, packets.ErrProtocolViolationZeroNonZeroExpiry)
}
// TestServerProcessPacketAuth checks that processing the TAuth fixture succeeds
// and writes nothing back to the client.
func TestServerProcessPacketAuth(t *testing.T) {
	srv := newServer()
	client, reader, writer := newTestClient()
	auth := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet
	go func() {
		require.NoError(t, srv.processPacket(client, auth))
		_ = writer.Close()
	}()

	written, err := io.ReadAll(reader)
	require.NoError(t, err)
	require.Equal(t, []byte{}, written)
}
// TestServerProcessPacketAuthInvalidReason sets the TAuth fixture's reason code
// to 99. It checks that processPacket then fails with
// ErrProtocolViolationInvalidReason.
func TestServerProcessPacketAuthInvalidReason(t *testing.T) {
	s := newServer()
	cl, _, _ := newTestClient()
	pkx := *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet
	pkx.ReasonCode = 99
	err := s.processPacket(cl, pkx)
	require.Error(t, err)
	// require.ErrorIs takes (err, target). The reversed order only passed while
	// err was the bare sentinel, and would break as soon as it got wrapped.
	require.ErrorIs(t, err, packets.ErrProtocolViolationInvalidReason)
}
// TestServerProcessPacketAuthFailure registers a modifiedHookBase with fail set.
// It checks that processAuth then returns an error matching errTestHook.
func TestServerProcessPacketAuthFailure(t *testing.T) {
	s := newServer()
	cl, _, _ := newTestClient()
	hook := new(modifiedHookBase)
	hook.fail = true
	err := s.AddHook(hook, nil)
	require.NoError(t, err)
	err = s.processAuth(cl, *packets.TPacketData[packets.Auth].Get(packets.TAuth).Packet)
	require.Error(t, err)
	// require.ErrorIs takes (err, target); the arguments were previously swapped.
	require.ErrorIs(t, err, errTestHook)
}
// TestServerSendLWT gives a "sender" client a will on "a/b/c" and subscribes a
// "receiver" client to that topic. After sendLWT(sender), it checks that the
// receiver reads exactly the TPublishBasic fixture bytes.
func TestServerSendLWT(t *testing.T) {
	srv := newServer()
	_ = srv.Serve()
	defer srv.Close()

	sender, _, senderW := newTestClient()
	sender.ID = "sender"
	sender.Properties.Will = Will{
		Flag:      1,
		TopicName: "a/b/c",
		Payload:   []byte("hello mochi"),
	}
	srv.Clients.Add(sender)

	receiver, receiverR, receiverW := newTestClient()
	receiver.ID = "receiver"
	srv.Clients.Add(receiver)
	srv.Topics.Subscribe(receiver.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})

	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.PacketsReceived))
	require.Equal(t, 0, len(srv.Topics.Messages("a/b/c")))

	received := make(chan []byte)
	go func() {
		data, err := io.ReadAll(receiverR)
		require.NoError(t, err)
		received <- data
	}()

	go func() {
		srv.sendLWT(sender)
		// Give the publish a moment to be written before the pipes are closed.
		time.Sleep(10 * time.Millisecond)
		_ = senderW.Close()
		_ = receiverW.Close()
	}()

	want := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes
	require.Equal(t, want, <-received)
}
// TestServerSendLWTRetain mirrors TestServerSendLWT but sets Retain on the will.
// The subscribed receiver is still expected to read exactly the TPublishBasic
// fixture bytes.
func TestServerSendLWTRetain(t *testing.T) {
	srv := newServer()
	_ = srv.Serve()
	defer srv.Close()

	lwtSender, _, lwtSenderConn := newTestClient()
	lwtSender.ID = "sender"
	lwtSender.Properties.Will = Will{
		Flag:      1,
		TopicName: "a/b/c",
		Payload:   []byte("hello mochi"),
		Retain:    true,
	}
	srv.Clients.Add(lwtSender)

	subscriber, subReader, subConn := newTestClient()
	subscriber.ID = "receiver"
	srv.Clients.Add(subscriber)
	srv.Topics.Subscribe(subscriber.ID, packets.Subscription{Filter: "a/b/c", Qos: 0})

	require.Equal(t, int64(0), atomic.LoadInt64(&srv.Info.PacketsReceived))
	require.Equal(t, 0, len(srv.Topics.Messages("a/b/c")))

	delivered := make(chan []byte)
	go func() {
		payload, readErr := io.ReadAll(subReader)
		require.NoError(t, readErr)
		delivered <- payload
	}()

	go func() {
		srv.sendLWT(lwtSender)
		time.Sleep(10 * time.Millisecond)
		_ = lwtSenderConn.Close()
		_ = subConn.Close()
	}()

	expected := packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes
	require.Equal(t, expected, <-delivered)
}
// TestServerSendLWTDelayed gives "cl1" a will with a 2s delay interval and checks
// that sendLWT queues it in loop.willDelayed. It then moves the queued entry's
// expiry into the past and calls sendDelayedLWT(now). The entry must leave the
// queue, and subscriber "cl2" must receive the TPublishBasic fixture bytes.
func TestServerSendLWTDelayed(t *testing.T) {
	srv := newServer()

	publisher, _, _ := newTestClient()
	publisher.ID = "cl1"
	publisher.Properties.Will = Will{
		Flag:              1,
		TopicName:         "a/b/c",
		Payload:           []byte("hello mochi"),
		Retain:            true,
		WillDelayInterval: 2,
	}
	srv.Clients.Add(publisher)

	subscriber, reader, writer := newTestClient()
	subscriber.ID = "cl2"
	srv.Clients.Add(subscriber)
	require.True(t, srv.Topics.Subscribe(subscriber.ID, packets.Subscription{Filter: "a/b/c"}))

	go func() {
		srv.sendLWT(publisher)
		delayed, found := srv.loop.willDelayed.Get(publisher.ID)
		require.True(t, found)

		// Backdate the expiry so sendDelayedLWT(now) treats the will as due.
		delayed.Expiry = time.Now().Unix() - 1
		srv.loop.willDelayed.Add(publisher.ID, delayed)
		require.Equal(t, 1, srv.loop.willDelayed.Len())

		srv.sendDelayedLWT(time.Now().Unix())
		require.Equal(t, 0, srv.loop.willDelayed.Len())

		time.Sleep(time.Millisecond)
		_ = writer.Close()
	}()

	received := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		received <- data
	}()

	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).RawBytes, <-received)
}
// TestServerReadStore checks that readStore returns an error when the storage
// hook fails at each loading stage. modifiedHookBase.failAt selects the failing
// stage; the stage labels below follow the original inline notes.
func TestServerReadStore(t *testing.T) {
	s := newServer()
	hook := new(modifiedHookBase)
	// Check hook registration instead of discarding the error. A failed
	// AddHook would otherwise surface later as a confusing readStore result.
	require.NoError(t, s.AddHook(hook, nil))

	hook.failAt = 1 // clients
	require.Error(t, s.readStore(), "expected readStore to fail loading clients")

	hook.failAt = 2 // subscriptions
	require.Error(t, s.readStore(), "expected readStore to fail loading subscriptions")

	hook.failAt = 3 // inflight
	require.Error(t, s.readStore(), "expected readStore to fail loading inflight messages")

	hook.failAt = 4 // retained
	require.Error(t, s.readStore(), "expected readStore to fail loading retained messages")

	hook.failAt = 5 // sys info
	require.Error(t, s.readStore(), "expected readStore to fail loading sys info")
}
// TestServerLoadClients restores eight stored clients and checks which ones
// loadClients keeps. Per the assertions, six are kept, and "v3-clean" and
// "v5-expire-interval-0" are not present afterwards.
func TestServerLoadClients(t *testing.T) {
	stored := []storage.Client{
		{ID: "mochi"},
		{ID: "zen"},
		{ID: "mochi-co"},
		{ID: "v3-clean", ProtocolVersion: 4, Clean: true},
		{ID: "v3-not-clean", ProtocolVersion: 4, Clean: false},
		{
			ID:              "v5-clean",
			ProtocolVersion: 5,
			Clean:           true,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 10,
			},
		},
		{
			ID:              "v5-expire-interval-0",
			ProtocolVersion: 5,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 0,
			},
		},
		{
			ID:              "v5-expire-interval-not-0",
			ProtocolVersion: 5,
			Properties: storage.ClientProperties{
				SessionExpiryInterval: 10,
			},
		},
	}

	srv := newServer()
	require.Equal(t, 0, srv.Clients.Len())
	srv.loadClients(stored)
	require.Equal(t, 6, srv.Clients.Len())

	mochi, found := srv.Clients.Get("mochi")
	require.True(t, found)
	require.Equal(t, "mochi", mochi.ID)

	expectations := []struct {
		id     string
		loaded bool
	}{
		{"v3-clean", false},
		{"v3-not-clean", true},
		{"v5-clean", true},
		{"v5-expire-interval-0", false},
		{"v5-expire-interval-not-0", true},
	}
	for _, want := range expectations {
		_, present := srv.Clients.Get(want.id)
		require.Equal(t, want.loaded, present, want.id)
	}
}
func TestServerLoadSubscriptions(t *testing.T) {
v := []storage.Subscription{
{ID: "sub1", Client: "mochi", Filter: "a/b/c"},
{ID: "sub2", Client: "mochi", Filter: "d/e/f", Qos: 1},
{ID: "sub3", Client: "mochi", Filter: "h/i/j", Qos: 2},
}
s := newServer()
cl, _, _ := newTestClient()
s.Clients.Add(cl)
require.Equal(t, 0, cl.State.Subscriptions.Len())
s.loadSubscriptions(v)
require.Equal(t, 3, cl.State.Subscriptions.Len())
}
// TestServerLoadInflightMessages restores three clients and four stored
// messages. It checks that each message lands in the inflight store of the
// client named by its Origin, under its PacketID, with payload and topic intact.
func TestServerLoadInflightMessages(t *testing.T) {
	s := newServer()
	s.loadClients([]storage.Client{
		{ID: "mochi"},
		{ID: "zen"},
		{ID: "mochi-co"},
	})
	require.Equal(t, 3, s.Clients.Len())

	v := []storage.Message{
		{Origin: "mochi", PacketID: 1, Payload: []byte("hello world"), TopicName: "a/b/c"},
		{Origin: "mochi", PacketID: 2, Payload: []byte("yes"), TopicName: "a/b/c"},
		{Origin: "zen", PacketID: 3, Payload: []byte("hello world"), TopicName: "a/b/c"},
		{Origin: "mochi-co", PacketID: 4, Payload: []byte("hello world"), TopicName: "a/b/c"},
	}
	s.loadInflight(v)

	cl, ok := s.Clients.Get("mochi")
	require.True(t, ok)
	require.Equal(t, "mochi", cl.ID)
	msg, ok := cl.State.Inflight.Get(2)
	require.True(t, ok)
	require.Equal(t, []byte{'y', 'e', 's'}, msg.Payload)
	require.Equal(t, "a/b/c", msg.TopicName)

	// "zen" was loaded but its inflight message was never checked before.
	cl, ok = s.Clients.Get("zen")
	require.True(t, ok)
	msg, ok = cl.State.Inflight.Get(3)
	require.True(t, ok)
	require.Equal(t, []byte("hello world"), msg.Payload)
	require.Equal(t, "a/b/c", msg.TopicName)

	// Previously msg was fetched here but never inspected (a dead assignment).
	cl, ok = s.Clients.Get("mochi-co")
	require.True(t, ok)
	msg, ok = cl.State.Inflight.Get(4)
	require.True(t, ok)
	require.Equal(t, []byte("hello world"), msg.Payload)
	require.Equal(t, "a/b/c", msg.TopicName)
}
// TestServerLoadRetainedMessages restores three retained messages on distinct
// topics. It checks that each topic reports one message and that an unrelated
// topic reports none.
func TestServerLoadRetainedMessages(t *testing.T) {
	srv := newServer()
	retainHeader := packets.FixedHeader{Retain: true}
	srv.loadRetained([]storage.Message{
		{Origin: "mochi", FixedHeader: retainHeader, Payload: []byte("hello world"), TopicName: "a/b/c"},
		{Origin: "mochi-co", FixedHeader: retainHeader, Payload: []byte("yes"), TopicName: "d/e/f"},
		{Origin: "zen", FixedHeader: retainHeader, Payload: []byte("hello world"), TopicName: "h/i/j"},
	})

	for _, topic := range []string{"a/b/c", "d/e/f", "h/i/j"} {
		require.Equal(t, 1, len(srv.Topics.Messages(topic)))
	}
	require.Equal(t, 0, len(srv.Topics.Messages("w/x/y")))
}
// TestServerClose registers a mock listener "t1" and an MQTT5 client on it, then
// starts and closes the server. It checks that the listener stops serving and
// that the client receives the TDisconnectShuttingDown fixture bytes.
func TestServerClose(t *testing.T) {
	srv := newServer()
	hook := new(modifiedHookBase)
	_ = srv.AddHook(hook, nil)

	client, reader, _ := newTestClient()
	client.Net.Listener = "t1"
	client.Properties.ProtocolVersion = 5
	srv.Clients.Add(client)

	require.NoError(t, srv.AddListener(listeners.NewMockListener("t1", ":1882")))
	_ = srv.Serve()

	// Capture everything written to the client; after Close this should be
	// the shutdown DISCONNECT.
	received := make(chan []byte)
	go func() {
		data, err := io.ReadAll(reader)
		require.NoError(t, err)
		received <- data
	}()
	time.Sleep(time.Millisecond)

	require.Equal(t, 1, srv.Listeners.Len())
	listener, ok := srv.Listeners.Get("t1")
	require.Equal(t, true, ok)
	mock := listener.(*listeners.MockListener)
	require.Equal(t, true, mock.IsServing())

	_ = srv.Close()
	time.Sleep(time.Millisecond)
	require.Equal(t, false, mock.IsServing())
	require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnectShuttingDown).RawBytes, <-received)
}
// TestServerClearExpiredInflights runs clearExpiredInflights with a 4s server cap
// on message expiry. Of five inflight packets, two must remain and the server's
// Inflight counter must drop to -3. With the cap set to 0, a newly added packet
// with a past Expiry is expected to survive the next sweep.
func TestServerClearExpiredInflights(t *testing.T) {
	srv := New(nil)
	require.NotNil(t, srv)
	srv.Options.Capabilities.MaximumMessageExpiryInterval = 4

	now := time.Now().Unix()
	client, _, _ := newTestClient()
	client.ops.info = srv.Info

	seed := []packets.Packet{
		{PacketID: 1, Expiry: now - 1},
		{PacketID: 2, Expiry: now - 2},
		{PacketID: 3, Created: now - 3}, // within bounds
		{PacketID: 5, Created: now - 5}, // over max server expiry limit
		{PacketID: 7, Created: now},
	}
	for _, pk := range seed {
		client.State.Inflight.Set(pk)
	}
	srv.Clients.Add(client)
	require.Len(t, client.State.Inflight.GetAll(false), 5)

	srv.clearExpiredInflights(now)
	require.Len(t, client.State.Inflight.GetAll(false), 2)
	require.Equal(t, int64(-3), srv.Info.Inflight)

	srv.Options.Capabilities.MaximumMessageExpiryInterval = 0
	client.State.Inflight.Set(packets.Packet{PacketID: 8, Expiry: now - 8})
	srv.clearExpiredInflights(now)
	require.Len(t, client.State.Inflight.GetAll(false), 3)
}
// TestServerClearExpiredRetained runs clearExpiredRetainedMessages in three
// phases and checks the retained-message count after each:
//   - v5 messages with a 4s server cap: 5 -> 2.
//   - Adding messages without ProtocolVersion 5: 6 -> 5.
//   - With the cap disabled (0), an old message survives: 6 remain.
func TestServerClearExpiredRetained(t *testing.T) {
	srv := New(nil)
	require.NotNil(t, srv)
	srv.Options.Capabilities.MaximumMessageExpiryInterval = 4
	now := time.Now().Unix()

	retain := func(filter string, pk packets.Packet) {
		srv.Topics.Retained.Add(filter, pk)
	}

	retain("a/b/c", packets.Packet{ProtocolVersion: 5, Created: now, Expiry: now - 1})
	retain("d/e/f", packets.Packet{ProtocolVersion: 5, Created: now, Expiry: now - 2})
	retain("g/h/i", packets.Packet{ProtocolVersion: 5, Created: now - 3}) // within bounds
	retain("j/k/l", packets.Packet{ProtocolVersion: 5, Created: now - 5}) // over max server expiry limit
	retain("m/n/o", packets.Packet{ProtocolVersion: 5, Created: now})
	require.Len(t, srv.Topics.Retained.GetAll(), 5)
	srv.clearExpiredRetainedMessages(now)
	require.Len(t, srv.Topics.Retained.GetAll(), 2)

	retain("p/q/r", packets.Packet{Created: now, Expiry: now - 1})
	retain("s/t/u", packets.Packet{Created: now, Expiry: now - 2}) // expiry is ineffective for v3.
	retain("v/w/x", packets.Packet{Created: now - 3})               // within bounds for v3
	retain("y/z/1", packets.Packet{Created: now - 5})               // over max server expiry limit
	require.Len(t, srv.Topics.Retained.GetAll(), 6)
	srv.clearExpiredRetainedMessages(now)
	require.Len(t, srv.Topics.Retained.GetAll(), 5)

	srv.Options.Capabilities.MaximumMessageExpiryInterval = 0
	retain("2/3/4", packets.Packet{Created: now - 8})
	srv.clearExpiredRetainedMessages(now)
	require.Len(t, srv.Topics.Retained.GetAll(), 6)
}
// TestServerClearExpiredClients registers one fresh client plus three MQTT5
// clients that disconnected 10s ago with session expiry intervals of 12, 8
// and 0. It checks that clearExpiredClients leaves two of the four clients.
//
// NOTE(review): the original labelled c2 (interval 0) as "No Expiry,
// indefinite session". The 4 -> 2 count suggests c2 is removed along with c1,
// so the label looks misleading — confirm against clearExpiredClients.
func TestServerClearExpiredClients(t *testing.T) {
	srv := New(nil)
	require.NotNil(t, srv)
	now := time.Now().Unix()

	fresh, _, _ := newTestClient()
	fresh.ID = "cl"
	srv.Clients.Add(fresh)

	// 12s interval, disconnected 10s ago.
	c0, _, _ := newTestClient()
	c0.ID = "c0"
	c0.State.disconnected = now - 10
	c0.State.cancelOpen()
	c0.Properties.ProtocolVersion = 5
	c0.Properties.Props.SessionExpiryIntervalFlag = true
	c0.Properties.Props.SessionExpiryInterval = 12
	srv.Clients.Add(c0)

	// 8s interval, disconnected 10s ago.
	c1, _, _ := newTestClient()
	c1.ID = "c1"
	c1.State.disconnected = now - 10
	c1.State.cancelOpen()
	c1.Properties.ProtocolVersion = 5
	c1.Properties.Props.SessionExpiryIntervalFlag = true
	c1.Properties.Props.SessionExpiryInterval = 8
	srv.Clients.Add(c1)

	// 0s interval, disconnected 10s ago.
	c2, _, _ := newTestClient()
	c2.ID = "c2"
	c2.State.disconnected = now - 10
	c2.State.cancelOpen()
	c2.Properties.ProtocolVersion = 5
	c2.Properties.Props.SessionExpiryIntervalFlag = true
	c2.Properties.Props.SessionExpiryInterval = 0
	srv.Clients.Add(c2)

	require.Equal(t, 4, srv.Clients.Len())
	srv.clearExpiredClients(now)
	require.Equal(t, 2, srv.Clients.Len())
}
// TestLoadServerInfoRestoreOnRestart enables RestoreSysInfoOnRestart and checks
// that loadServerInfo copies the stored BytesReceived value into s.Info.
func TestLoadServerInfoRestoreOnRestart(t *testing.T) {
	srv := New(nil)
	srv.Options.Capabilities.Compatibilities.RestoreSysInfoOnRestart = true
	srv.loadServerInfo(system.Info{BytesReceived: 60})
	require.Equal(t, int64(60), srv.Info.BytesReceived)
}
// TestItoa checks the decimal string produced by Int64toa.
func TestItoa(t *testing.T) {
	const want = "22"
	require.Equal(t, want, Int64toa(int64(22)))
}
// TestServerSubscribe runs a sequence of inline Subscribe calls on one server.
// The cases cover repeats, different identifiers, a replacement handler and a
// $SYS topic, and check the returned error for each. An invalid filter must
// yield ErrTopicFilterInvalid and a nil handler ErrInlineSubscriptionHandlerInvalid.
// The cases share state and run in order.
func TestServerSubscribe(t *testing.T) {
	noop := func(cl *Client, sub packets.Subscription, pk packets.Packet) {}
	srv := newServerWithInlineClient()
	require.NotNil(t, srv)

	cases := []struct {
		name    string
		filter  string
		id      int
		handler InlineSubFn
		wantErr error
	}{
		{name: "subscribe", filter: "a/b/c", id: 1, handler: noop},
		{name: "re-subscribe", filter: "a/b/c", id: 1, handler: noop},
		{name: "subscribe d/e/f", filter: "d/e/f", id: 1, handler: noop},
		{name: "re-subscribe d/e/f by different identifier", filter: "d/e/f", id: 2, handler: noop},
		{
			name:    "subscribe different handler",
			filter:  "a/b/c",
			id:      1,
			handler: func(cl *Client, sub packets.Subscription, pk packets.Packet) {},
		},
		{name: "subscribe $SYS/info", filter: "$SYS/info", id: 1, handler: noop},
		{name: "subscribe invalid ###", filter: "###", id: 1, handler: noop, wantErr: packets.ErrTopicFilterInvalid},
		{name: "subscribe invalid handler", filter: "a/b/c", id: 1, handler: nil, wantErr: packets.ErrInlineSubscriptionHandlerInvalid},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := srv.Subscribe(tc.filter, tc.id, tc.handler)
			require.Equal(t, tc.wantErr, got)
		})
	}
}
// TestServerSubscribeNoInlineClient checks that Subscribe on a server built with
// newServer returns ErrInlineClientNotEnabled.
func TestServerSubscribeNoInlineClient(t *testing.T) {
	srv := newServer()
	handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {}
	err := srv.Subscribe("a/b/c", 1, handler)
	require.Error(t, err)
	require.ErrorIs(t, err, ErrInlineClientNotEnabled)
}
// TestServerUnsubscribe creates three inline subscriptions and removes each one.
// It also checks that unsubscribing from an unknown filter succeeds and that an
// invalid filter is rejected with ErrTopicFilterInvalid.
func TestServerUnsubscribe(t *testing.T) {
	handler := func(cl *Client, sub packets.Subscription, pk packets.Packet) {}
	s := newServerWithInlineClient()

	// Use the error-specific assertions, matching the rest of this file.
	require.NoError(t, s.Subscribe("a/b/c", 1, handler))
	require.NoError(t, s.Subscribe("d/e/f", 1, handler))
	require.NoError(t, s.Subscribe("d/e/f", 2, handler))

	require.NoError(t, s.Unsubscribe("a/b/c", 1))
	require.NoError(t, s.Unsubscribe("d/e/f", 1))
	require.NoError(t, s.Unsubscribe("d/e/f", 2))

	// Unsubscribing from a filter with no subscription is not an error.
	require.NoError(t, s.Unsubscribe("not/exist", 1))

	// ErrorIs also tolerates the sentinel being wrapped, unlike a plain Equal.
	require.ErrorIs(t, s.Unsubscribe("#/#/invalid", 1), packets.ErrTopicFilterInvalid)
}
// TestServerUnsubscribeNoInlineClient checks that Unsubscribe on a server built
// with newServer returns ErrInlineClientNotEnabled.
func TestServerUnsubscribeNoInlineClient(t *testing.T) {
	srv := newServer()
	err := srv.Unsubscribe("a/b/c", 1)
	require.Error(t, err)
	require.ErrorIs(t, err, ErrInlineClientNotEnabled)
}
// TestPublishToInlineSubscriber subscribes an inline handler to "a/b/c" with
// identifier 1 and publishes the TPublishBasic fixture. It checks that the
// handler sees the expected payload, client ID, listener, filter and identifier.
func TestPublishToInlineSubscriber(t *testing.T) {
	srv := newServerWithInlineClient()
	done := make(chan bool)

	onMessage := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		done <- true
	}
	err := srv.Subscribe("a/b/c", 1, onMessage)
	require.Nil(t, err)

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	go srv.publishToSubscribers(publish)

	require.Equal(t, true, <-done)
}
// TestPublishToInlineSubscribersDifferentFilter registers inline handlers on
// "a/b/c" and "z/e/n" and publishes one fixture to each topic. It checks that
// each handler fires once with its own payload and filter.
func TestPublishToInlineSubscribersDifferentFilter(t *testing.T) {
	srv := newServerWithInlineClient()
	const subscribers = 2
	done := make(chan bool, subscribers)

	// expect builds a handler that asserts the delivery details and signals done.
	expect := func(filter string, id int, payload string) func(*Client, packets.Subscription, packets.Packet) {
		return func(cl *Client, sub packets.Subscription, pk packets.Packet) {
			require.Equal(t, []byte(payload), pk.Payload)
			require.Equal(t, InlineClientId, cl.ID)
			require.Equal(t, LocalListener, cl.Net.Listener)
			require.Equal(t, filter, sub.Filter)
			require.Equal(t, id, sub.Identifier)
			done <- true
		}
	}

	require.Nil(t, srv.Subscribe("a/b/c", 1, expect("a/b/c", 1, "hello mochi")))
	require.Nil(t, srv.Subscribe("z/e/n", 1, expect("z/e/n", 1, "mochi mochi")))

	go func() {
		srv.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet)
		srv.publishToSubscribers(*packets.TPacketData[packets.Publish].Get(packets.TPublishCopyBasic).Packet)
	}()

	for i := 0; i < subscribers; i++ {
		require.Equal(t, true, <-done)
	}
}
// TestPublishToInlineSubscribersDifferentIdentifier registers two inline handlers
// on "a/b/c" with identifiers 1 and 2 and publishes TPublishBasic once. It
// checks that both handlers fire and each sees its own identifier.
func TestPublishToInlineSubscribersDifferentIdentifier(t *testing.T) {
	srv := newServerWithInlineClient()
	const subscribers = 2
	done := make(chan bool, subscribers)

	// expect builds a handler that asserts the delivery details and signals done.
	expect := func(filter string, id int, payload string) func(*Client, packets.Subscription, packets.Packet) {
		return func(cl *Client, sub packets.Subscription, pk packets.Packet) {
			require.Equal(t, []byte(payload), pk.Payload)
			require.Equal(t, InlineClientId, cl.ID)
			require.Equal(t, LocalListener, cl.Net.Listener)
			require.Equal(t, filter, sub.Filter)
			require.Equal(t, id, sub.Identifier)
			done <- true
		}
	}

	require.Nil(t, srv.Subscribe("a/b/c", 1, expect("a/b/c", 1, "hello mochi")))
	require.Nil(t, srv.Subscribe("a/b/c", 2, expect("a/b/c", 2, "hello mochi")))

	publish := *packets.TPacketData[packets.Publish].Get(packets.TPublishBasic).Packet
	go srv.publishToSubscribers(publish)

	for i := 0; i < subscribers; i++ {
		require.Equal(t, true, <-done)
	}
}
// TestServerSubscribeWithRetain retains the TPublishRetain fixture and then
// subscribes an inline handler to "a/b/c". It checks that the handler receives
// the retained message with the expected delivery details.
func TestServerSubscribeWithRetain(t *testing.T) {
	srv := newServerWithInlineClient()
	const subscribers = 1
	done := make(chan bool, subscribers)

	retainedCount := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retainedCount)

	onRetained := func(cl *Client, sub packets.Subscription, pk packets.Packet) {
		require.Equal(t, []byte("hello mochi"), pk.Payload)
		require.Equal(t, InlineClientId, cl.ID)
		require.Equal(t, LocalListener, cl.Net.Listener)
		require.Equal(t, "a/b/c", sub.Filter)
		require.Equal(t, 1, sub.Identifier)
		done <- true
	}
	require.Nil(t, srv.Subscribe("a/b/c", 1, onRetained))

	require.Equal(t, true, <-done)
}
// TestServerSubscribeWithRetainDifferentFilter retains one message on each of two
// topics and then subscribes an inline handler to each topic. It checks that
// every handler receives its own retained message.
func TestServerSubscribeWithRetainDifferentFilter(t *testing.T) {
	srv := newServerWithInlineClient()
	const subscribers = 2
	done := make(chan bool, subscribers)

	for _, fixture := range []byte{packets.TPublishRetain, packets.TPublishCopyBasic} {
		retainedCount := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(fixture).Packet)
		require.Equal(t, int64(1), retainedCount)
	}

	// expect builds a handler that asserts the delivery details and signals done.
	expect := func(filter string, id int, payload string) func(*Client, packets.Subscription, packets.Packet) {
		return func(cl *Client, sub packets.Subscription, pk packets.Packet) {
			require.Equal(t, []byte(payload), pk.Payload)
			require.Equal(t, InlineClientId, cl.ID)
			require.Equal(t, LocalListener, cl.Net.Listener)
			require.Equal(t, filter, sub.Filter)
			require.Equal(t, id, sub.Identifier)
			done <- true
		}
	}

	require.Nil(t, srv.Subscribe("a/b/c", 1, expect("a/b/c", 1, "hello mochi")))
	require.Nil(t, srv.Subscribe("z/e/n", 1, expect("z/e/n", 1, "mochi mochi")))

	for i := 0; i < subscribers; i++ {
		require.Equal(t, true, <-done)
	}
}
// TestServerSubscribeWithRetainDifferentIdentifier retains one message on "a/b/c"
// and then subscribes two inline handlers to it with identifiers 1 and 2. It
// checks that both receive the retained message with their own identifier.
func TestServerSubscribeWithRetainDifferentIdentifier(t *testing.T) {
	srv := newServerWithInlineClient()
	const subscribers = 2
	done := make(chan bool, subscribers)

	retainedCount := srv.Topics.RetainMessage(*packets.TPacketData[packets.Publish].Get(packets.TPublishRetain).Packet)
	require.Equal(t, int64(1), retainedCount)

	// expect builds a handler that asserts the delivery details and signals done.
	expect := func(filter string, id int, payload string) func(*Client, packets.Subscription, packets.Packet) {
		return func(cl *Client, sub packets.Subscription, pk packets.Packet) {
			require.Equal(t, []byte(payload), pk.Payload)
			require.Equal(t, InlineClientId, cl.ID)
			require.Equal(t, LocalListener, cl.Net.Listener)
			require.Equal(t, filter, sub.Filter)
			require.Equal(t, id, sub.Identifier)
			done <- true
		}
	}

	require.Nil(t, srv.Subscribe("a/b/c", 1, expect("a/b/c", 1, "hello mochi")))
	require.Nil(t, srv.Subscribe("a/b/c", 2, expect("a/b/c", 2, "hello mochi")))

	for i := 0; i < subscribers; i++ {
		require.Equal(t, true, <-done)
	}
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/vanve/server.git
git@gitee.com:vanve/server.git
vanve
server
server
file-based-config

搜索帮助