From ee878b25c9c5910208aa413d4d97ab2204b4dcdd Mon Sep 17 00:00:00 2001 From: notify Date: Tue, 22 Jul 2025 14:01:10 +0800 Subject: [PATCH 1/4] =?UTF-8?q?cbor=E5=92=8Cjson=E4=B8=A4=E6=89=8B?= =?UTF-8?q?=E6=8A=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lua/server/rpc/cbor.lua | 594 +++++++++++++++++++++++++++++++++ lua/server/rpc/dispatchers.lua | 14 +- lua/server/rpc/entry.lua | 42 ++- lua/server/rpc/fk.lua | 56 +++- lua/server/rpc/jsonrpc.lua | 89 +++-- lua/server/rpc/stdio.lua | 3 + 6 files changed, 744 insertions(+), 54 deletions(-) create mode 100644 lua/server/rpc/cbor.lua diff --git a/lua/server/rpc/cbor.lua b/lua/server/rpc/cbor.lua new file mode 100644 index 0000000..b220e9a --- /dev/null +++ b/lua/server/rpc/cbor.lua @@ -0,0 +1,594 @@ +-- SPDX-License-Identifier: MIT +-- Copied from https://github.com/Zash/lua-cbor + +-- Concise Binary Object Representation (CBOR) +-- RFC 7049 + +local function softreq(pkg, field) + local ok, mod = pcall(require, pkg); + if not ok then return end + if field then return mod[field]; end + return mod; +end +local dostring = function(s) + local ok, f = load(function() + local ret = s; + s = nil + return ret; + end); + if ok and f then + return f(); + end +end + +local setmetatable = setmetatable; +local getmetatable = getmetatable; +local dbg_getmetatable = debug.getmetatable; +local assert = assert; +local error = error; +local type = type; +local pairs = pairs; +local ipairs = ipairs; +local tostring = tostring; +local s_char = string.char; +local t_concat = table.concat; +local t_sort = table.sort; +local m_floor = math.floor; +local m_abs = math.abs; +local m_huge = math.huge; +local m_max = math.max; +local maxint = math.maxinteger or 9007199254740992; +local minint = math.mininteger or -9007199254740992; +local NaN = 0 / 0; +local m_frexp = math.frexp; +local m_ldexp = math.ldexp or function(x, exp) return x * 2.0 ^ exp; end; +local m_type = math.type or function(n) return n % 1 == 0 and n <= maxint and n >= minint and "integer" or "float" end; +local s_pack = string.pack or softreq("struct", "pack"); +local s_unpack = string.unpack or softreq("struct", "unpack"); +local b_rshift = softreq("bit32", "rshift") or softreq("bit", "rshift") or + dostring "return function(a,b) return a >> b end" or + function(a, b) return m_max(0, m_floor(a / (2 ^ b))); end; + +-- sanity check +if s_pack and s_pack(">I2", 0) ~= "\0\0" then + s_pack = nil; +end +if s_unpack and s_unpack(">I2", "\1\2\3\4") ~= 0x102 then + s_unpack = nil; +end + +local _ENV = nil; -- luacheck: ignore 211 + +local encoder = {}; + +local function encode(obj, opts) + return encoder[type(obj)](obj, opts); +end + +-- Major types 0, 1 and length encoding for others +local function integer(num, m) + if m == 0 and num < 0 then + -- negative integer, major type 1 + num, m = -num - 1, 32; + end + if num < 24 then + return s_char(m + num); + elseif num < 2 ^ 8 then + return s_char(m + 24, num); + elseif num < 2 ^ 16 then + return s_char(m + 25, b_rshift(num, 8), num % 0x100); + elseif num < 2 ^ 32 then + return s_char(m + 26, + b_rshift(num, 24) % 0x100, + b_rshift(num, 16) % 0x100, + b_rshift(num, 8) % 0x100, + num % 0x100); + elseif num < 2 ^ 64 then + local high = m_floor(num / 2 ^ 32); + num = num % 2 ^ 32; + return s_char(m + 27, + b_rshift(high, 24) % 0x100, + b_rshift(high, 16) % 0x100, + b_rshift(high, 8) % 0x100, + high % 0x100, + b_rshift(num, 24) % 0x100, + b_rshift(num, 16) % 0x100, + b_rshift(num, 8) % 0x100, + num % 0x100); + end + error "int too large"; +end + +if s_pack then + function integer(num, m) + local fmt; + m = m or 0; + if num < 24 then + fmt, m = ">B", m + num; + elseif num < 256 then + fmt, m = ">BB", m + 24; + elseif num < 65536 then + fmt, m = ">BI2", m + 25; + elseif num < 4294967296 then + fmt, m = ">BI4", m + 26; + else + fmt, m = ">BI8", m + 27; + end + return s_pack(fmt, m, num); + end +end + +local simple_mt = {}; +function simple_mt:__tostring() return self.name or ("simple(%d)"):format(self.value); end + +function simple_mt:__tocbor() return self.cbor or integer(self.value, 224); end + +local function simple(value, name, cbor) + assert(value >= 0 and value <= 255, "bad argument #1 to 'simple' (integer in range 0..255 expected)"); + return setmetatable({ value = value, name = name, cbor = cbor }, simple_mt); +end + +local tagged_mt = {}; +function tagged_mt:__tostring() return ("%d(%s)"):format(self.tag, tostring(self.value)); end + +function tagged_mt:__tocbor(opts) return integer(self.tag, 192) .. encode(self.value, opts); end + +local function tagged(tag, value) + assert(tag >= 0, "bad argument #1 to 'tagged' (positive integer expected)"); + return setmetatable({ tag = tag, value = value }, tagged_mt); +end + +local null = simple(22, "null"); -- explicit null +local undefined = simple(23, "undefined"); -- undefined or nil +local BREAK = simple(31, "break", "\255"); + +-- Number types dispatch +function encoder.number(num) + return encoder[m_type(num)](num); +end + +-- Major types 0, 1 +function encoder.integer(num) + if num < 0 then + return integer(-1 - num, 32); + end + return integer(num, 0); +end + +-- Major type 7 +function encoder.float(num) + if num ~= num then -- NaN shortcut + return "\251\127\255\255\255\255\255\255\255"; + end + local sign = (num > 0 or 1 / num > 0) and 0 or 1; + num = m_abs(num) + if num == m_huge then + return s_char(251, sign * 128 + 128 - 1) .. "\240\0\0\0\0\0\0"; + end + local fraction, exponent = m_frexp(num) + if fraction == 0 then + return s_char(251, sign * 128) .. "\0\0\0\0\0\0\0"; + end + fraction = fraction * 2; + exponent = exponent + 1024 - 2; + if exponent <= 0 then + fraction = fraction * 2 ^ (exponent - 1) + exponent = 0; + else + fraction = fraction - 1; + end + return s_char(251, + sign * 2 ^ 7 + m_floor(exponent / 2 ^ 4) % 2 ^ 7, + exponent % 2 ^ 4 * 2 ^ 4 + + m_floor(fraction * 2 ^ 4 % 0x100), + m_floor(fraction * 2 ^ 12 % 0x100), + m_floor(fraction * 2 ^ 20 % 0x100), + m_floor(fraction * 2 ^ 28 % 0x100), + m_floor(fraction * 2 ^ 36 % 0x100), + m_floor(fraction * 2 ^ 44 % 0x100), + m_floor(fraction * 2 ^ 52 % 0x100) + ) +end + +if s_pack then + function encoder.float(num) + return s_pack(">Bd", 251, num); + end +end + + +-- Major type 2 - byte strings +function encoder.bytestring(s) + return integer(#s, 64) .. s; +end + +-- Major type 3 - UTF-8 strings +function encoder.utf8string(s) + return integer(#s, 96) .. s; +end + +-- Lua strings are byte strings +encoder.string = encoder.bytestring; + +function encoder.boolean(bool) + return bool and "\245" or "\244"; +end + +encoder["nil"] = function() return "\246"; end + +function encoder.userdata(ud, opts) + local mt = dbg_getmetatable(ud); + if mt then + local encode_ud = opts and opts[mt] or mt.__tocbor; + if encode_ud then + return encode_ud(ud, opts); + end + end + error "can't encode userdata"; +end + +function encoder.table(t, opts) + local mt = getmetatable(t); + if mt then + local encode_t = opts and opts[mt] or mt.__tocbor; + if encode_t then + return encode_t(t, opts); + end + end + -- the table is encoded as an array iff when we iterate over it, + -- we see successive integer keys starting from 1. The lua + -- language doesn't actually guarantee that this will be the case + -- when we iterate over a table with successive integer keys, but + -- due an implementation detail in PUC Rio Lua, this is what we + -- usually observe. See the Lua manual regarding the # (length) + -- operator. In the case that this does not happen, we will fall + -- back to a map with integer keys, which becomes a bit larger. + local array, map, i, p = { integer(#t, 128) }, { "\191" }, 1, 2; + local is_array = true; + for k, v in pairs(t) do + is_array = is_array and i == k; + i = i + 1; + + local encoded_v = encode(v, opts); + array[i] = encoded_v; + + map[p], p = encode(k, opts), p + 1; + map[p], p = encoded_v, p + 1; + end + -- map[p] = "\255"; + map[1] = integer(i - 1, 160); + return t_concat(is_array and array or map); +end + +-- Array or dict-only encoders, which can be set as __tocbor metamethod +function encoder.array(t, opts) + local array = {}; + for i, v in ipairs(t) do + array[i] = encode(v, opts); + end + return integer(#array, 128) .. t_concat(array); +end + +function encoder.map(t, opts) + local map, p, len = { "\191" }, 2, 0; + for k, v in pairs(t) do + map[p], p = encode(k, opts), p + 1; + map[p], p = encode(v, opts), p + 1; + len = len + 1; + end + -- map[p] = "\255"; + map[1] = integer(len, 160); + return t_concat(map); +end + +encoder.dict = encoder.map; -- COMPAT + +function encoder.ordered_map(t, opts) + local map = {}; + if not t[1] then -- no predefined order + local i = 0; + for k in pairs(t) do + i = i + 1; + map[i] = k; + end + t_sort(map); + end + for i, k in ipairs(t[1] and t or map) do + map[i] = encode(k, opts) .. encode(t[k], opts); + end + return integer(#map, 160) .. t_concat(map); +end + +encoder["function"] = function() + error "can't encode function"; +end + +-- Decoder +-- Reads from a file-handle like object +local function read_bytes(fh, len) + return fh:read(len); +end + +local function read_byte(fh) + return fh:read(1):byte(); +end + +local function read_length(fh, mintyp) + if mintyp < 24 then + return mintyp; + elseif mintyp < 28 then + local out = 0; + for _ = 1, 2 ^ (mintyp - 24) do + out = out * 256 + read_byte(fh); + end + return out; + else + error "invalid length"; + end +end + +local decoder = {}; + +local function read_type(fh) + local byte = read_byte(fh); + return b_rshift(byte, 5), byte % 32; +end + +local function read_object(fh, opts) + local typ, mintyp = read_type(fh); + return decoder[typ](fh, mintyp, opts); +end + +local function read_integer(fh, mintyp) + return read_length(fh, mintyp); +end + +local function read_negative_integer(fh, mintyp) + return -1 - read_length(fh, mintyp); +end + +local function read_string(fh, mintyp) + if mintyp ~= 31 then + return read_bytes(fh, read_length(fh, mintyp)); + end + local out = {}; + local i = 1; + local v = read_object(fh); + while v ~= BREAK do + out[i], i = v, i + 1; + v = read_object(fh); + end + return t_concat(out); +end + +local function read_unicode_string(fh, mintyp) + return read_string(fh, mintyp); + -- local str = read_string(fh, mintyp); + -- if have_utf8 and not utf8.len(str) then + -- TODO How to handle this? + -- end + -- return str; +end + +local function read_array(fh, mintyp, opts) + local out = {}; + if mintyp == 31 then + local i = 1; + local v = read_object(fh, opts); + while v ~= BREAK do + out[i], i = v, i + 1; + v = read_object(fh, opts); + end + else + local len = read_length(fh, mintyp); + for i = 1, len do + out[i] = read_object(fh, opts); + end + end + return out; +end + +local function read_map(fh, mintyp, opts) + local out = {}; + local k; + if mintyp == 31 then + local i = 1; + k = read_object(fh, opts); + while k ~= BREAK do + out[k], i = read_object(fh, opts), i + 1; + k = read_object(fh, opts); + end + else + local len = read_length(fh, mintyp); + for _ = 1, len do + k = read_object(fh, opts); + out[k] = read_object(fh, opts); + end + end + return out; +end + +local tagged_decoders = {}; + +local function read_semantic(fh, mintyp, opts) + local tag = read_length(fh, mintyp); + local value = read_object(fh, opts); + local postproc = opts and opts[tag] or tagged_decoders[tag]; + if postproc then + return postproc(value); + end + return tagged(tag, value); +end + +local function read_half_float(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + + fraction = fraction + (exponent * 256) % 1024; -- copy two(?) bits from exponent to fraction + exponent = b_rshift(exponent, 2) % 32; -- remove sign bit and two low bits from fraction; + + if exponent == 0 then + return sign * m_ldexp(fraction, -24); + elseif exponent ~= 31 then + return sign * m_ldexp(fraction + 1024, exponent - 25); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end +end + +local function read_float(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + exponent = exponent * 2 % 256 + b_rshift(fraction, 7); + fraction = fraction % 128; + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + + if exponent == 0 then + return sign * m_ldexp(exponent, -149); + elseif exponent ~= 0xff then + return sign * m_ldexp(fraction + 2 ^ 23, exponent - 150); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end +end + +local function read_double(fh) + local exponent = read_byte(fh); + local fraction = read_byte(fh); + local sign = exponent < 128 and 1 or -1; -- sign is highest bit + + exponent = exponent % 128 * 16 + b_rshift(fraction, 4); + fraction = fraction % 16; + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + fraction = fraction * 256 + read_byte(fh); + + if exponent == 0 then + return sign * m_ldexp(exponent, -149); + elseif exponent ~= 0xff then + return sign * m_ldexp(fraction + 2 ^ 52, exponent - 1075); + elseif fraction == 0 then + return sign * m_huge; + else + return NaN; + end +end + + +if s_unpack then + function read_float(fh) return s_unpack(">f", read_bytes(fh, 4)) end + + function read_double(fh) return s_unpack(">d", read_bytes(fh, 8)) end +end + +local function read_simple(fh, value, opts) + if value == 24 then + value = read_byte(fh); + end + if value == 20 then + return false; + elseif value == 21 then + return true; + elseif value == 22 then + return null; + elseif value == 23 then + return undefined; + elseif value == 25 then + return read_half_float(fh); + elseif value == 26 then + return read_float(fh); + elseif value == 27 then + return read_double(fh); + elseif value == 31 then + return BREAK; + end + if opts and opts.simple then + return opts.simple(value); + end + return simple(value); +end + +decoder[0] = read_integer; +decoder[1] = read_negative_integer; +decoder[2] = read_string; +decoder[3] = read_unicode_string; +decoder[4] = read_array; +decoder[5] = read_map; +decoder[6] = read_semantic; +decoder[7] = read_simple; + +-- opts.more(n) -> want more data +-- opts.simple -> decode simple value +-- opts[int] -> tagged decoder +local function decode(s, opts) + local fh = {}; + local pos = 1; + + local more; + if type(opts) == "function" then + more = opts; + elseif type(opts) == "table" then + more = opts.more; + elseif opts ~= nil then + error(("bad argument #2 to 'decode' (function or table expected, got %s)"):format(type(opts))); + end + if type(more) ~= "function" then + function more() + error "input too short"; + end + end + + function fh:read(bytes) + local ret = s:sub(pos, pos + bytes - 1); + if #ret < bytes then + ret = more(bytes - #ret, fh, opts); + if ret then self:write(ret); end + return self:read(bytes); + end + pos = pos + bytes; + return ret; + end + + function fh:write(bytes) -- luacheck: no self + s = s .. bytes; + if pos > 256 then + s = s:sub(pos + 1); + pos = 1; + end + return #bytes; + end + + return read_object(fh, opts); +end + +return { + -- en-/decoder functions + encode = encode, + decode = decode, + decode_file = read_object, + + -- tables of per-type en-/decoders + type_encoders = encoder, + type_decoders = decoder, + + -- special treatment for tagged values + tagged_decoders = tagged_decoders, + + -- constructors for annotated types + simple = simple, + tagged = tagged, + + -- pre-defined simple values + null = null, + undefined = undefined, +}; diff --git a/lua/server/rpc/dispatchers.lua b/lua/server/rpc/dispatchers.lua index b798e4a..ec23834 100644 --- a/lua/server/rpc/dispatchers.lua +++ b/lua/server/rpc/dispatchers.lua @@ -25,16 +25,18 @@ local resumeRoom = function(params) end local handleRequest = function(params) + print(json.encode(params)) if type(params[1]) ~= "string" then return false, nil end local ok, ret = pcall(HandleRequest, params[1]) - if not ok then return false, 'internal_error' end + if not ok then return false, 'internal_error', 'internal_error' end return true, ret end local setPlayerState = function(params) + print(json.encode(params)) if not (type(params[1]) == "number" and type(params[2]) == "number" and type(params[3] == "number")) then return false, nil end @@ -45,7 +47,7 @@ local setPlayerState = function(params) local room = GetRoom(roomId) if not room then - return false, "Room not found" + return false, "internal_error", "Room not found" end for _, p in ipairs(room.room:getPlayers()) do @@ -55,7 +57,7 @@ local setPlayerState = function(params) end end - return false, "Player not found" + return false, "internal_error", "Player not found" end local addObserver = function(params) @@ -68,7 +70,7 @@ local addObserver = function(params) local room = GetRoom(roomId) if not room then - return false, "Room not found" + return false, "internal_error", "Room not found" end table.insert(room.room:getObservers(), fk.ServerPlayer(obj)) @@ -85,7 +87,7 @@ local removeObserver = function(params) local room = GetRoom(roomId) if not room then - return false, "Room not found" + return false, "internal_error", "Room not found" end local observers = room.room:getObservers() @@ -96,7 +98,7 @@ local removeObserver = function(params) end end - return false, "Player not found" + return false, "internal_error", "Player not found" end ---@type table diff --git a/lua/server/rpc/entry.lua b/lua/server/rpc/entry.lua index a521776..b054443 100644 --- a/lua/server/rpc/entry.lua +++ b/lua/server/rpc/entry.lua @@ -5,10 +5,13 @@ package.path = package.path .. "./?.lua;./?/init.lua;./lua/lib/?.lua;./lua/?.lua;./lua/?/init.lua" +local os = os fk = require "server.rpc.fk" +local RPC_MODE = os.getenv("FK_RPC_MODE") == "cbor" and "cbor" or "json" local jsonrpc = require "server.rpc.jsonrpc" local stdio = require "server.rpc.stdio" local dispatchers = require "server.rpc.dispatchers" +local cbor = require 'lua.server.rpc.cbor' -- 加载新月杀相关内容并ban掉两个吃stdin的 dofile "lua/freekill.lua" @@ -49,17 +52,38 @@ coroutine.resume = function(co, ...) return coresume(co, ...) end -local mainLoop = function() - InitScheduler(fk.RoomThread()) - stdio.send(jsonrpc.encode_rpc(jsonrpc.notification, "hello", { "world" })) +local mainLoop +if RPC_MODE == "json" then + mainLoop = function() + InitScheduler(fk.RoomThread()) + stdio.send(jsonrpc.encode_rpc(jsonrpc.notification, "hello", { "world" })) - while true do - local msg = stdio.receive() - if msg == nil then break end + while true do + local msg = stdio.receive() + if msg == nil then break end - local res = jsonrpc.server_response(dispatchers, msg) - if res then - stdio.send(json.encode(res)) + local res = jsonrpc.server_response(dispatchers, msg) + if res then + stdio.send(json.encode(res)) + end + end + end +elseif RPC_MODE == "cbor" then + mainLoop = function() + InitScheduler(fk.RoomThread()) + stdio.stdout:write(jsonrpc.encode_rpc(jsonrpc.notification, "hello", { "world" })) + stdio.stdout:flush() + + while true do + if fk._rpc_finished then break end + local msg = cbor.decode_file(stdio.stdin) + if msg == nil then break end + + local res = jsonrpc.server_response(dispatchers, msg) + if res then + stdio.stdout:write(cbor.encode(res)) + stdio.stdout:flush() + end end end end diff --git a/lua/server/rpc/fk.lua b/lua/server/rpc/fk.lua index 1c2d80a..630b727 100644 --- a/lua/server/rpc/fk.lua +++ b/lua/server/rpc/fk.lua @@ -7,6 +7,9 @@ local os = os +local RPC_MODE = os.getenv("FK_RPC_MODE") == "cbor" and "cbor" or "json" +local cbor = require 'lua.server.rpc.cbor' + -- 下面俩是系统上要安装的 freekill不提供 -- 需安装lua-posix包 @@ -26,29 +29,45 @@ local dispatchers = require "server.rpc.dispatchers" local notifyRpc = function(method, params) local req = jsonrpc.notification(method, params) - stdio.send(json.encode(req)) + if RPC_MODE == "json" then + stdio.send(json.encode(req)) + else + stdio.stdout:write(cbor.encode(req)) + stdio.stdout:flush() + end end local callRpc = function(method, params) local req = jsonrpc.request(method, params) local id = req.id - stdio.send(json.encode(req)) + if RPC_MODE == "json" then + stdio.send(json.encode(req)) + else + stdio.stdout:write(cbor.encode(req)) + stdio.stdout:flush() + end while true do - local msg = stdio.receive() - if msg == nil then break end - - local ok, packet = pcall(json.decode, msg) - if not ok then - stdio.send(json.encode(jsonrpc.response_error(req, 'parse_error', packet))) - goto continue + local msg, packet + if RPC_MODE == "json" then + msg = stdio.receive() + if msg == nil then break end + + local ok + ok, packet = pcall(json.decode, msg) + if not ok then + stdio.send(json.encode(jsonrpc.response_error(req, 'parse_error', packet))) + goto continue + end + else + packet = cbor.decode_file(stdio.stdin) + if packet == nil then break end end - if packet.jsonrpc == "2.0" and packet.id == id and packet.method == nil then - ---@cast packet JsonRpcPacket - return packet.result, packet.error - - elseif packet.error then + if packet[jsonrpc.key_jsonrpc] == "2.0" and + packet[jsonrpc.key_id] == id and packet[jsonrpc.key_method] == nil then + return packet[jsonrpc.key_result], packet[jsonrpc.key_error] + elseif packet[jsonrpc.key_error] then -- 和Json RPC spec不合的一集,我们可能收到预期之外的error -- 这可能是我io编程不达标导致的 -- 对于这种id不合的error包扔了 @@ -56,7 +75,12 @@ local callRpc = function(method, params) else local res = jsonrpc.server_response(dispatchers, packet) if res then - stdio.send(json.encode(res)) + if RPC_MODE == "json" then + stdio.send(json.encode(res)) + else + stdio.stdout:write(cbor.encode(res)) + stdio.stdout:flush() + end end end @@ -80,7 +104,7 @@ fk.QList = function(arr) return setmetatable(arr, { __index = { at = function(self, i) - return self[i+1] + return self[i + 1] end, length = function(self) return #self diff --git a/lua/server/rpc/jsonrpc.lua b/lua/server/rpc/jsonrpc.lua index 25ab5c1..a61534f 100644 --- a/lua/server/rpc/jsonrpc.lua +++ b/lua/server/rpc/jsonrpc.lua @@ -15,7 +15,36 @@ ---@field message string ---@field data? any +local os = os local json = require 'json' +local cbor = require 'lua.server.rpc.cbor' +local RPC_MODE = os.getenv("FK_RPC_MODE") == "cbor" and "cbor" or "json" + +local key_jsonrpc, key_method, key_params, key_error, key_id, key_result, +key_error_code, key_error_message, key_error_data +if RPC_MODE == 'json' then + key_jsonrpc = "jsonrpc" + key_method = "method" + key_params = "params" + key_error = "error" + key_id = "id" + key_result = "result" + + key_error_code = "code" + key_error_message = "message" + key_error_data = "data" +elseif RPC_MODE == 'cbor' then + key_jsonrpc = 100 + key_method = 101 + key_params = 102 + key_error = 103 + key_id = 104 + key_result = 105 + + key_error_code = 200 + key_error_message = 201 + key_error_data = 202 +end -- Standard error objects which can be extended with user defined error objects -- having error codes from -32000 to -32099. Helper functions add_error_object @@ -99,7 +128,11 @@ local function encode_rpc(func, method, params, id) error("Function cannot be found") end local obj = func(method, params, id) - return json.encode(obj) + if RPC_MODE == "json" then + return json.encode(obj) + else + return cbor.encode(obj) + end end ---@return JsonRpcPacket @@ -108,9 +141,9 @@ local function notification(method, params) error("A method in an RPC cannot be empty") end local req = { - jsonrpc = "2.0", - method = method, - params = params, + [key_jsonrpc] = "2.0", + [key_method] = method, + [key_params] = params, } return req end @@ -118,9 +151,9 @@ end ---@return JsonRpcPacket local function request(method, params, id) local req = notification(method, params) - req.id = id or _reqId - if req.id == _reqId then - _reqId = req.id + 1 + req[key_id] = id or _reqId + if req[key_id] == _reqId then + _reqId = req[key_id] + 1 if _reqId > 10000000 then _reqId = 1 end end return req @@ -130,9 +163,9 @@ end ---@return JsonRpcPacket local function response(req, results) return { - jsonrpc = "2.0", - id = req.id, - result = results, + [key_jsonrpc] = "2.0", + [key_id] = req[key_id], + [key_result] = results, } end @@ -145,18 +178,18 @@ local function response_error(req, error_name, data) get_error_object('internal_error') ---@cast error_object -nil - res.jsonrpc = "2.0" - res.error = { - code = error_object.code, - message = error_object.message, + res[key_jsonrpc] = "2.0" + res[key_error] = { + [key_error_code] = error_object.code, + [key_error_message] = error_object.message, } - res.data = data + res[key_error][key_error_data] = data if (error_object.code == -32700) or (error_object.code == -32600) then - res.id = nil --json.util.null() + res[key_id] = nil --json.util.null() else - res.id = req.id + res[key_id] = req[key_id] end return res end @@ -167,20 +200,20 @@ local function handle_request(methods, req) if type(req) ~= 'table' then return response_error(req, 'invalid_request', req) end - if type(req.method) ~= 'string' then + if type(req[key_method]) ~= 'string' then return response_error(req, 'invalid_request', req) end - if type(req.id) ~= 'number' or req.id <= 0 then + if type(req[key_id]) ~= 'number' or req[key_id] <= 0 then return response_error(req, 'invalid_request', req) end - local fnc = methods[req.method] + local fnc = methods[req[key_method]] -- Method not found if type(fnc) ~= 'function' then return response_error(req, 'method_not_found') end - local params = req.params + local params = req[key_params] if params == nil then params = {} end @@ -213,7 +246,7 @@ local function handle_request(methods, req) end -- Notification only - if not req.id then + if not req[key_id] then return nil end @@ -246,7 +279,7 @@ local function server_response(methods, request) end ---@cast req -string - if (#req == 0) and (req.jsonrpc == "2.0") then + if (#req == 0) and (req[key_jsonrpc] == "2.0") then return handle_request(methods, req) elseif #req == 0 then return response_error(req, 'invalid_request', req) @@ -281,4 +314,14 @@ M.get_next_free_id = get_next_free_id M.server_response = server_response +M.key_jsonrpc = key_jsonrpc +M.key_method = key_method +M.key_params = key_params +M.key_error = key_error +M.key_id = key_id +M.key_result = key_result +M.key_error_code = key_error_code +M.key_error_message = key_error_message +M.key_error_data = key_error_data + return M diff --git a/lua/server/rpc/stdio.lua b/lua/server/rpc/stdio.lua index cb8b907..a9534ba 100644 --- a/lua/server/rpc/stdio.lua +++ b/lua/server/rpc/stdio.lua @@ -22,5 +22,8 @@ end local M = {} M.receive = receive M.send = send +-- 提供给cbor的接口,它需要xxx:read() +M.stdin = io.input() +M.stdout = io.output() return M -- Gitee From 6c0a353c83a37c75b10ec1d19325370aebf76219 Mon Sep 17 00:00:00 2001 From: notify Date: Tue, 22 Jul 2025 14:33:16 +0800 Subject: [PATCH 2/4] =?UTF-8?q?cbor=E7=AA=81=E7=84=B6=E5=B0=B1=E8=83=BD?= =?UTF-8?q?=E7=94=A8=E4=BA=86=EF=BC=9F=EF=BC=88notification=E7=9A=84io?= =?UTF-8?q?=E6=97=A0=E6=B3=95=E6=88=98=E8=83=9C=EF=BC=8C=E5=B9=B2=E8=84=86?= =?UTF-8?q?=E5=85=A8=E6=9D=80=E4=BA=86=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lua/server/rpc/dispatchers.lua | 2 -- lua/server/rpc/fk.lua | 24 +++++++----------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/lua/server/rpc/dispatchers.lua b/lua/server/rpc/dispatchers.lua index ec23834..7f8e84b 100644 --- a/lua/server/rpc/dispatchers.lua +++ b/lua/server/rpc/dispatchers.lua @@ -25,7 +25,6 @@ local resumeRoom = function(params) end local handleRequest = function(params) - print(json.encode(params)) if type(params[1]) ~= "string" then return false, nil end @@ -36,7 +35,6 @@ local handleRequest = function(params) end local setPlayerState = function(params) - print(json.encode(params)) if not (type(params[1]) == "number" and type(params[2]) == "number" and type(params[3] == "number")) then return false, nil end diff --git a/lua/server/rpc/fk.lua b/lua/server/rpc/fk.lua index 630b727..9071969 100644 --- a/lua/server/rpc/fk.lua +++ b/lua/server/rpc/fk.lua @@ -27,19 +27,9 @@ local jsonrpc = require "server.rpc.jsonrpc" local stdio = require "server.rpc.stdio" local dispatchers = require "server.rpc.dispatchers" -local notifyRpc = function(method, params) - local req = jsonrpc.notification(method, params) - if RPC_MODE == "json" then - stdio.send(json.encode(req)) - else - stdio.stdout:write(cbor.encode(req)) - stdio.stdout:flush() - end -end - local callRpc = function(method, params) local req = jsonrpc.request(method, params) - local id = req.id + local id = req[jsonrpc.key_id] if RPC_MODE == "json" then stdio.send(json.encode(req)) else @@ -65,7 +55,7 @@ local callRpc = function(method, params) end if packet[jsonrpc.key_jsonrpc] == "2.0" and - packet[jsonrpc.key_id] == id and packet[jsonrpc.key_method] == nil then + packet[jsonrpc.key_id] == id and type(packet[jsonrpc.key_method]) ~= "string" then return packet[jsonrpc.key_result], packet[jsonrpc.key_error] elseif packet[jsonrpc.key_error] then -- 和Json RPC spec不合的一集,我们可能收到预期之外的error @@ -123,19 +113,19 @@ end fk.QRandomGenerator = qrandom.new function fk.qDebug(fmt, ...) - notifyRpc("qDebug", { string.format(fmt, ...) }) + callRpc("qDebug", { string.format(fmt, ...) }) end function fk.qInfo(fmt, ...) - notifyRpc("qInfo", { string.format(fmt, ...) }) + callRpc("qInfo", { string.format(fmt, ...) }) end function fk.qWarning(fmt, ...) - notifyRpc("qWarning", { string.format(fmt, ...) }) + callRpc("qWarning", { string.format(fmt, ...) }) end function fk.qCritical(fmt, ...) - notifyRpc("qCritical", { string.format(fmt, ...) }) + callRpc("qCritical", { string.format(fmt, ...) }) end -- 连print也要?! @@ -147,7 +137,7 @@ function print(...) for i = 1, n do table.insert(params, tostring(args[i])) end - notifyRpc("print", params) + callRpc("print", params) end -- swig/player.i -- Gitee From 6d35155ce3b3fab92681e9b6a1392ffcaa8ca279 Mon Sep 17 00:00:00 2001 From: notify Date: Tue, 22 Jul 2025 19:49:42 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E7=94=B1=E4=BA=8Eio.read(0)=E4=BC=9A?= =?UTF-8?q?=E9=98=BB=E5=A1=9E=EF=BC=8C=E5=A5=97=E5=A3=B3=E4=B9=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lua/server/rpc/stdio.lua | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/lua/server/rpc/stdio.lua b/lua/server/rpc/stdio.lua index a9534ba..680c885 100644 --- a/lua/server/rpc/stdio.lua +++ b/lua/server/rpc/stdio.lua @@ -23,7 +23,14 @@ local M = {} M.receive = receive M.send = send -- 提供给cbor的接口,它需要xxx:read() -M.stdin = io.input() +M.stdin = { + read = function(_, n) + -- 没想到io.input():read(0)的情况下依然会等stdin有数据可读才返回 + -- 我们用的cbor并不希望这种情况发生 + if n == 0 then return nil end + return io.read(n) + end, +} M.stdout = io.output() return M -- Gitee From 23474af4d06afac70f32edec3d80cddfaa2a33c5 Mon Sep 17 00:00:00 2001 From: notify Date: Tue, 22 Jul 2025 22:38:32 +0800 Subject: [PATCH 4/4] =?UTF-8?q?io.read(0)=E6=8C=89=E7=85=A7=E5=BA=93?= =?UTF-8?q?=E5=BA=94=E8=AF=A5=E8=BF=94=E5=9B=9E=E7=A9=BA=E4=B8=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lua/server/rpc/stdio.lua | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lua/server/rpc/stdio.lua b/lua/server/rpc/stdio.lua index 680c885..25ae010 100644 --- a/lua/server/rpc/stdio.lua +++ b/lua/server/rpc/stdio.lua @@ -25,9 +25,10 @@ M.send = send -- 提供给cbor的接口,它需要xxx:read() M.stdin = { read = function(_, n) + if fk._rpc_finished then return "" end -- 没想到io.input():read(0)的情况下依然会等stdin有数据可读才返回 -- 我们用的cbor并不希望这种情况发生 - if n == 0 then return nil end + if n == 0 then return "" end return io.read(n) end, } -- Gitee