From f07cbae7c5272cd6eafae09b8172e448c4156094 Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Fri, 26 Jul 2024 17:44:33 +0800 Subject: [PATCH 1/8] fix http2 local cache and cpu oversize Signed-off-by: Tiga Ultraman --- ylong_http/src/h2/error.rs | 4 +- ylong_http_client/src/async_impl/client.rs | 18 +- .../src/async_impl/conn/http2.rs | 20 +- ylong_http_client/src/lib.rs | 12 +- ylong_http_client/src/util/config/http.rs | 10 + ylong_http_client/src/util/dispatcher.rs | 137 +++++- ylong_http_client/src/util/h2/manager.rs | 405 ++++++++++++------ ylong_http_client/src/util/h2/mod.rs | 2 +- ylong_http_client/src/util/h2/output.rs | 192 +++++++-- ylong_http_client/src/util/h2/streams.rs | 48 ++- 10 files changed, 623 insertions(+), 225 deletions(-) diff --git a/ylong_http/src/h2/error.rs b/ylong_http/src/h2/error.rs index 99a6033..8ec3279 100644 --- a/ylong_http/src/h2/error.rs +++ b/ylong_http/src/h2/error.rs @@ -29,7 +29,7 @@ use std::convert::{Infallible, TryFrom}; use crate::error::{ErrorKind, HttpError}; /// The http2 error handle implementation -#[derive(Debug, Eq, PartialEq, Clone)] +#[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum H2Error { /// [`Stream Error`] Handling. /// @@ -45,7 +45,7 @@ pub enum H2Error { /// [`Error Codes`] implementation. /// /// [`Error Codes`]: https://httpwg.org/specs/rfc9113.html#ErrorCodes -#[derive(Debug, Eq, PartialEq, Clone)] +#[derive(Debug, Eq, PartialEq, Copy, Clone)] pub enum ErrorCode { /// The associated condition is not a result of an error. For example, /// a `GOAWAY` might include this code to indicate graceful shutdown of a diff --git a/ylong_http_client/src/async_impl/client.rs b/ylong_http_client/src/async_impl/client.rs index b894a32..658e3b3 100644 --- a/ylong_http_client/src/async_impl/client.rs +++ b/ylong_http_client/src/async_impl/client.rs @@ -498,6 +498,20 @@ impl ClientBuilder { self } + /// Sets allowed max size of local cached frame. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let config = ClientBuilder::new().allow_cached_frame_num(10); + /// ``` + pub fn allow_cached_frame_num(mut self, num: usize) -> Self { + self.http.http2_config.set_allow_cached_frame_num(num); + self + } + /// Sets the `SETTINGS_MAX_FRAME_SIZE`. /// /// # Examples @@ -575,7 +589,7 @@ impl ClientBuilder { impl ClientBuilder { /// Sets the maximum allowed TLS version for connections. /// - /// By default there's no maximum. + /// By default, there's no maximum. /// /// # Examples /// @@ -592,7 +606,7 @@ impl ClientBuilder { /// Sets the minimum required TLS version for connections. /// - /// By default the TLS backend's own default is used. + /// By default, the TLS backend's own default is used. /// /// # Examples /// diff --git a/ylong_http_client/src/async_impl/conn/http2.rs b/ylong_http_client/src/async_impl/conn/http2.rs index 23cb3ef..4dd2457 100644 --- a/ylong_http_client/src/async_impl/conn/http2.rs +++ b/ylong_http_client/src/async_impl/conn/http2.rs @@ -453,10 +453,10 @@ mod ut_http2 { use crate::async_impl::conn::http2::TextIo; use crate::util::dispatcher::http2::Http2Conn; - let (resp_tx, resp_rx) = crate::runtime::unbounded_channel(); + let (resp_tx, resp_rx) = ylong_runtime::sync::mpsc::bounded_channel(20); let (req_tx, _req_rx) = crate::runtime::unbounded_channel(); let shutdown = Arc::new(AtomicBool::new(false)); - let mut conn: Http2Conn<()> = Http2Conn::new(1, shutdown, req_tx); + let mut conn: Http2Conn<()> = Http2Conn::new(1, 20, shutdown, req_tx); conn.receiver.set_receiver(resp_rx); let mut text_io = TextIo::new(conn); let data_1 = Frame::new( @@ -474,9 +474,19 @@ mod ut_http2 { FrameFlags::new(1), Payload::Data(Data::new(vec![b'a'; 10])), ); - let _ = resp_tx.send(crate::util::dispatcher::http2::RespMessage::Output(data_1)); - let _ = resp_tx.send(crate::util::dispatcher::http2::RespMessage::Output(data_2)); - let _ = resp_tx.send(crate::util::dispatcher::http2::RespMessage::Output(data_3)); + + ylong_runtime::block_on(async { + let _ = resp_tx + .send(crate::util::dispatcher::http2::RespMessage::Output(data_1)) + .await; + let _ = resp_tx + .send(crate::util::dispatcher::http2::RespMessage::Output(data_2)) + .await; + let _ = resp_tx + .send(crate::util::dispatcher::http2::RespMessage::Output(data_3)) + .await; + }); + ylong_runtime::block_on(async { let mut buf = [0_u8; 10]; let mut output_vec = vec![]; diff --git a/ylong_http_client/src/lib.rs b/ylong_http_client/src/lib.rs index 0137861..2f687a9 100644 --- a/ylong_http_client/src/lib.rs +++ b/ylong_http_client/src/lib.rs @@ -72,7 +72,11 @@ pub(crate) mod runtime { io::{split, ReadHalf, WriteHalf}, spawn, sync::{ - mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + mpsc::{ + channel as bounded_channel, error::SendError, unbounded_channel, + Receiver as BoundedReceiver, Sender as BoundedSender, UnboundedReceiver, + UnboundedSender, + }, Mutex as AsyncMutex, MutexGuard, }, task::JoinHandle, @@ -94,7 +98,11 @@ pub(crate) mod runtime { pub(crate) use ylong_runtime::{ spawn, sync::{ - mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + error::SendError, + mpsc::{ + bounded_channel, unbounded_channel, BoundedReceiver, BoundedSender, + UnboundedReceiver, UnboundedSender, + }, Mutex as AsyncMutex, MutexGuard, }, task::JoinHandle, diff --git a/ylong_http_client/src/util/config/http.rs b/ylong_http_client/src/util/config/http.rs index ea6f3b9..2a237b6 100644 --- a/ylong_http_client/src/util/config/http.rs +++ b/ylong_http_client/src/util/config/http.rs @@ -75,6 +75,7 @@ pub(crate) mod http2 { init_conn_window_size: u32, init_stream_window_size: u32, enable_push: bool, + allow_cached_frame_num: usize, } impl H2Config { @@ -106,6 +107,10 @@ pub(crate) mod http2 { self.init_stream_window_size = size; } + pub(crate) fn set_allow_cached_frame_num(&mut self, num: usize) { + self.allow_cached_frame_num = num; + } + /// Gets the SETTINGS_MAX_FRAME_SIZE. pub(crate) fn max_frame_size(&self) -> u32 { self.max_frame_size @@ -132,6 +137,10 @@ pub(crate) mod http2 { pub(crate) fn stream_window_size(&self) -> u32 { self.init_stream_window_size } + + pub(crate) fn allow_cached_frame_num(&self) -> usize { + self.allow_cached_frame_num + } } impl Default for H2Config { @@ -143,6 +152,7 @@ pub(crate) mod http2 { init_conn_window_size: DEFAULT_CONN_WINDOW_SIZE, init_stream_window_size: DEFAULT_STREAM_WINDOW_SIZE, enable_push: false, + allow_cached_frame_num: 5, } } } diff --git a/ylong_http_client/src/util/dispatcher.rs b/ylong_http_client/src/util/dispatcher.rs index c7c23a5..ab5028d 100644 --- a/ylong_http_client/src/util/dispatcher.rs +++ b/ylong_http_client/src/util/dispatcher.rs @@ -148,6 +148,7 @@ pub(crate) mod http1 { #[cfg(feature = "http2")] pub(crate) mod http2 { use std::collections::HashMap; + use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; @@ -157,16 +158,19 @@ pub(crate) mod http2 { use ylong_http::error::HttpError; use ylong_http::h2::{ ErrorCode, Frame, FrameDecoder, FrameEncoder, FrameFlags, Goaway, H2Error, Payload, - Settings, SettingsBuilder, + RstStream, Settings, SettingsBuilder, }; use crate::runtime::{ - unbounded_channel, AsyncRead, AsyncWrite, AsyncWriteExt, UnboundedReceiver, - UnboundedSender, WriteHalf, + bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, AsyncWriteExt, BoundedReceiver, + BoundedSender, SendError, UnboundedReceiver, UnboundedSender, WriteHalf, }; use crate::util::config::H2Config; use crate::util::dispatcher::{ConnDispatcher, Dispatcher}; - use crate::util::h2::{ConnManager, FlowControl, RecvData, RequestWrapper, SendData, Streams}; + use crate::util::h2::{ + ConnManager, FlowControl, H2StreamState, RecvData, RequestWrapper, SendData, + StreamEndState, Streams, + }; use crate::ErrorKind::Request; use crate::{ErrorKind, HttpClientError}; @@ -175,6 +179,9 @@ pub(crate) mod http2 { const DEFAULT_MAX_HEADER_LIST_SIZE: usize = 16 << 20; const DEFAULT_WINDOW_SIZE: u32 = 65535; + pub(crate) type ManagerSendFut = + Pin>> + Send + Sync>>; + pub(crate) enum RespMessage { Output(Frame), OutputExit(DispatchErrorKind), @@ -187,11 +194,11 @@ pub(crate) mod http2 { pub(crate) struct ReqMessage { pub(crate) id: u32, - pub(crate) sender: UnboundedSender, + pub(crate) sender: BoundedSender, pub(crate) request: RequestWrapper, } - #[derive(Debug, Eq, PartialEq, Clone)] + #[derive(Debug, Eq, PartialEq, Copy, Clone)] pub(crate) enum DispatchErrorKind { H2(H2Error), Io(std::io::ErrorKind), @@ -203,6 +210,7 @@ pub(crate) mod http2 { // threads according to HTTP2 syntax. pub(crate) struct Http2Dispatcher { pub(crate) next_stream_id: StreamId, + pub(crate) allow_cached_frames: usize, pub(crate) sender: UnboundedSender, pub(crate) io_shutdown: Arc, pub(crate) handles: Vec>, @@ -212,6 +220,7 @@ pub(crate) mod http2 { pub(crate) struct Http2Conn { // Handle id pub(crate) id: u32, + pub(crate) allow_cached_frames: usize, // Sends frame to StreamController pub(crate) sender: UnboundedSender, pub(crate) receiver: RespReceiver, @@ -224,11 +233,12 @@ pub(crate) mod http2 { // closed. pub(crate) io_shutdown: Arc, // The senders of all connected stream channels of response. - pub(crate) senders: HashMap>, + pub(crate) senders: HashMap>, + pub(crate) curr_message: HashMap, // Stream information on the connection. pub(crate) streams: Streams, // Received GO_AWAY frame. - pub(crate) go_away: Option, + pub(crate) recved_go_away: Option, // The last GO_AWAY frame sent by the client. pub(crate) go_away_sync: GoAwaySync, } @@ -257,7 +267,7 @@ pub(crate) mod http2 { #[derive(Default)] pub(crate) struct RespReceiver { - receiver: Option>, + receiver: Option>, } impl ConnDispatcher @@ -294,10 +304,19 @@ pub(crate) mod http2 { // being. let mut handles = Vec::with_capacity(3); if input_tx.send(settings).is_ok() { - Self::launch(controller, req_rx, input_tx, input_rx, &mut handles, io); + Self::launch( + config.allow_cached_frame_num(), + controller, + req_rx, + input_tx, + input_rx, + &mut handles, + io, + ); } Self { next_stream_id, + allow_cached_frames: config.allow_cached_frame_num(), sender: req_tx, io_shutdown: shutdown_flag, handles, @@ -306,6 +325,7 @@ pub(crate) mod http2 { } fn launch( + allow_num: usize, controller: StreamController, req_rx: UnboundedReceiver, input_tx: UnboundedSender, @@ -313,7 +333,7 @@ pub(crate) mod http2 { handles: &mut Vec>, io: S, ) { - let (resp_tx, resp_rx) = unbounded_channel(); + let (resp_tx, resp_rx) = bounded_channel(allow_num); let (read, write) = crate::runtime::split(io); let settings_sync = Arc::new(Mutex::new(SettingsSync::default())); let send_settings_sync = settings_sync.clone(); @@ -339,9 +359,7 @@ pub(crate) mod http2 { let manager = crate::runtime::spawn(async move { let mut conn_manager = ConnManager::new(settings_sync, input_tx, resp_rx, req_rx, controller); - if let Err(e) = Pin::new(&mut conn_manager).await { - conn_manager.exit_with_error(e); - } + let _ = Pin::new(&mut conn_manager).await; }); handles.push(manager); } @@ -356,7 +374,12 @@ pub(crate) mod http2 { return None; } let sender = self.sender.clone(); - let handle = Http2Conn::new(id, self.io_shutdown.clone(), sender); + let handle = Http2Conn::new( + id, + self.allow_cached_frames, + self.io_shutdown.clone(), + sender, + ); Some(handle) } @@ -379,11 +402,13 @@ pub(crate) mod http2 { impl Http2Conn { pub(crate) fn new( id: u32, + allow_cached_num: usize, io_shutdown: Arc, sender: UnboundedSender, ) -> Self { Self { id, + allow_cached_frames: allow_cached_num, sender, receiver: RespReceiver::default(), io_shutdown, @@ -395,7 +420,7 @@ pub(crate) mod http2 { &mut self, request: RequestWrapper, ) -> Result<(), HttpClientError> { - let (tx, rx) = unbounded_channel::(); + let (tx, rx) = bounded_channel::(self.allow_cached_frames); self.receiver.set_receiver(rx); self.sender .send(ReqMessage { @@ -420,8 +445,9 @@ pub(crate) mod http2 { Self { io_shutdown: shutdown, senders: HashMap::new(), + curr_message: HashMap::new(), streams, - go_away: None, + recved_go_away: None, go_away_sync: GoAwaySync::default(), } } @@ -430,7 +456,7 @@ pub(crate) mod http2 { self.io_shutdown.store(true, Ordering::Release); } - pub(crate) fn go_away_unsent_stream( + pub(crate) fn get_unsent_streams( &mut self, last_stream_id: u32, ) -> Result, H2Error> { @@ -443,21 +469,86 @@ pub(crate) mod http2 { Ok(self.streams.get_go_away_streams(last_stream_id)) } - pub(crate) fn send_message_to_stream(&mut self, stream_id: u32, message: RespMessage) { + pub(crate) fn send_message_to_stream( + &mut self, + cx: &mut Context<'_>, + stream_id: u32, + message: RespMessage, + ) -> Poll> { if let Some(sender) = self.senders.get(&stream_id) { // If the client coroutine has exited, this frame is skipped. - match sender.send(message) { - Ok(_) => {} - Err(_e) => { + let mut tx = { + let sender = sender.clone(); + let ft = async move { sender.send(message).await }; + Box::pin(ft) + }; + + match tx.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + // The current coroutine sending the request exited prematurely. + Poll::Ready(Err(_)) => { self.senders.remove(&stream_id); + Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError))) + } + Poll::Pending => { + self.curr_message.insert(stream_id, tx); + Poll::Pending + } + } + } else { + Poll::Ready(Err(H2Error::StreamError(stream_id, ErrorCode::NoError))) + } + } + + pub(crate) fn poll_blocked_message( + &mut self, + cx: &mut Context<'_>, + input_tx: &UnboundedSender, + ) -> Poll<()> { + let keys: Vec = self.curr_message.keys().cloned().collect(); + let mut blocked = false; + + for key in keys { + if let Some(mut task) = self.curr_message.remove(&key) { + match task.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => {} + // The current coroutine sending the request exited prematurely. + Poll::Ready(Err(_)) => { + self.senders.remove(&key); + if let Some(state) = self.streams.stream_state(key) { + if !matches!(state, H2StreamState::Closed(_)) { + if let StreamEndState::OK = self.streams.send_local_reset(key) { + let rest_payload = + RstStream::new(ErrorCode::NoError.into_code()); + let frame = Frame::new( + key as usize, + FrameFlags::empty(), + Payload::RstStream(rest_payload), + ); + // ignore the send error occurs here in order to finish all + // tasks. + let _ = input_tx.send(frame); + } + } + } + } + Poll::Pending => { + self.curr_message.insert(key, task); + blocked = true; + } } } } + if blocked { + Poll::Pending + } else { + Poll::Ready(()) + } } } impl RespReceiver { - pub(crate) fn set_receiver(&mut self, receiver: UnboundedReceiver) { + pub(crate) fn set_receiver(&mut self, receiver: BoundedReceiver) { self.receiver = Some(receiver); } diff --git a/ylong_http_client/src/util/h2/manager.rs b/ylong_http_client/src/util/h2/manager.rs index 5e4bdb5..4b06c63 100644 --- a/ylong_http_client/src/util/h2/manager.rs +++ b/ylong_http_client/src/util/h2/manager.rs @@ -22,20 +22,29 @@ use ylong_http::h2::{ ErrorCode, Frame, FrameFlags, Goaway, H2Error, Payload, Ping, RstStream, Setting, }; -use crate::runtime::{UnboundedReceiver, UnboundedSender}; +use crate::runtime::{BoundedReceiver, UnboundedReceiver, UnboundedSender}; use crate::util::dispatcher::http2::{ DispatchErrorKind, OutputMessage, ReqMessage, RespMessage, SettingsState, SettingsSync, StreamController, }; use crate::util::h2::streams::{DataReadState, FrameRecvState, StreamEndState}; +#[derive(Copy, Clone)] +enum ManagerState { + Send, + Receive, + Exit(DispatchErrorKind), +} + pub(crate) struct ConnManager { + state: ManagerState, + next_state: ManagerState, // Synchronize SETTINGS frames sent by the client. settings: Arc>, // channel transmitter between manager and io input. input_tx: UnboundedSender, // channel receiver between manager and io output. - resp_rx: UnboundedReceiver, + resp_rx: BoundedReceiver, // channel receiver between manager and stream coroutine. req_rx: UnboundedReceiver, controller: StreamController, @@ -47,42 +56,60 @@ impl Future for ConnManager { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let manager = self.get_mut(); loop { - // Receives a response frame from io output. - match manager.resp_rx.poll_recv(cx) { - #[cfg(feature = "tokio_base")] - Poll::Ready(Some(message)) => match message { - OutputMessage::Output(frame) => { - manager.poll_recv_message(frame)?; - } - // io output occurs error. - OutputMessage::OutputExit(e) => { - manager.manage_resp_error(e)?; - } - }, - #[cfg(feature = "ylong_base")] - Poll::Ready(Ok(message)) => match message { - OutputMessage::Output(frame) => { - manager.poll_recv_message(frame)?; + match manager.state { + ManagerState::Send => { + if manager.poll_blocked_frames(cx).is_pending() { + return Poll::Pending; } - // io output occurs error. - OutputMessage::OutputExit(e) => { - manager.manage_resp_error(e)?; - } - }, - #[cfg(feature = "tokio_base")] - Poll::Ready(None) => { - manager.exit_with_error(DispatchErrorKind::ChannelClosed); - return Poll::Ready(Ok(())); - } - #[cfg(feature = "ylong_base")] - Poll::Ready(Err(_e)) => { - manager.exit_with_error(DispatchErrorKind::ChannelClosed); - return Poll::Ready(Ok(())); } - - Poll::Pending => { - return manager.manage_pending_state(cx); + ManagerState::Receive => { + // Receives a response frame from io output. + match manager.resp_rx.poll_recv(cx) { + #[cfg(feature = "tokio_base")] + Poll::Ready(Some(message)) => match message { + OutputMessage::Output(frame) => { + if manager.poll_recv_message(cx, frame)?.is_pending() { + return Poll::Pending; + } + } + // io output occurs error. + OutputMessage::OutputExit(e) => { + // Note error returned immediately. + if manager.manage_resp_error(cx, e)?.is_pending() { + return Poll::Pending; + } + } + }, + #[cfg(feature = "ylong_base")] + Poll::Ready(Ok(message)) => match message { + OutputMessage::Output(frame) => { + if manager.poll_recv_message(cx, frame)?.is_pending() { + return Poll::Pending; + } + } + // io output occurs error. + OutputMessage::OutputExit(e) => { + if manager.manage_resp_error(cx, e)?.is_pending() { + return Poll::Pending; + } + } + }, + #[cfg(feature = "tokio_base")] + Poll::Ready(None) => { + return manager.poll_channel_closed_exit(cx); + } + #[cfg(feature = "ylong_base")] + Poll::Ready(Err(_e)) => { + return manager.poll_channel_closed_exit(cx); + } + + Poll::Pending => { + // TODO manage error state. + return manager.manage_pending_state(cx); + } + } } + ManagerState::Exit(e) => return Poll::Ready(Err(e)), } } } @@ -92,11 +119,13 @@ impl ConnManager { pub(crate) fn new( settings: Arc>, input_tx: UnboundedSender, - resp_rx: UnboundedReceiver, + resp_rx: BoundedReceiver, req_rx: UnboundedReceiver, controller: StreamController, ) -> Self { Self { + state: ManagerState::Receive, + next_state: ManagerState::Receive, settings, input_tx, resp_rx, @@ -110,7 +139,7 @@ impl ConnManager { cx: &mut Context<'_>, ) -> Poll> { // The manager previously accepted a GOAWAY Frame. - if let Some(code) = self.controller.go_away { + if let Some(code) = self.controller.recved_go_away { self.poll_deal_with_go_away(code)?; } self.controller.streams.window_update_conn(&self.input_tx)?; @@ -167,9 +196,11 @@ impl ConnManager { } fn poll_input_request(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchErrorKind> { - loop { - self.controller.streams.try_consume_pending_concurrency(); - match self.controller.streams.next_stream() { + self.controller.streams.try_consume_pending_concurrency(); + let size = self.controller.streams.pending_stream_num(); + let mut index = 0; + while index < size { + match self.controller.streams.next_pending_stream() { None => { break; } @@ -177,6 +208,7 @@ impl ConnManager { self.input_stream_frame(cx, id)?; } } + index += 1; } Ok(()) } @@ -243,7 +275,11 @@ impl ConnManager { .map_err(|_e| DispatchErrorKind::ChannelClosed) } - fn poll_recv_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { + fn poll_recv_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame, + ) -> Poll> { match frame.payload() { Payload::Settings(_settings) => { self.recv_settings_frame(frame)?; @@ -253,19 +289,21 @@ impl ConnManager { } Payload::PushPromise(_) => { // TODO The current settings_enable_push setting is fixed to false. - return Err(H2Error::ConnectionError(ErrorCode::ProtocolError).into()); + return Poll::Ready(Err( + H2Error::ConnectionError(ErrorCode::ProtocolError).into() + )); } Payload::Goaway(_go_away) => { - self.recv_go_away_frame(frame)?; + return self.recv_go_away_frame(cx, frame).map_err(Into::into); } Payload::RstStream(_reset) => { - self.recv_reset_frame(frame)?; + return self.recv_reset_frame(cx, frame).map_err(Into::into); } Payload::Headers(_headers) => { - self.recv_header_frame(frame)?; + return self.recv_header_frame(cx, frame).map_err(Into::into); } Payload::Data(_data) => { - self.recv_data_frame(frame)?; + return self.recv_data_frame(cx, frame).map_err(Into::into); } Payload::WindowUpdate(_windows) => { self.recv_window_frame(frame)?; @@ -273,7 +311,7 @@ impl ConnManager { // Priority is no longer recommended, so keep it compatible but not processed. Payload::Priority(_priority) => {} } - Ok(()) + Poll::Ready(Ok(())) } fn recv_settings_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { @@ -342,93 +380,116 @@ impl ConnManager { } } - fn recv_go_away_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { + fn recv_go_away_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame, + ) -> Poll> { let go_away = if let Payload::Goaway(goaway) = frame.payload() { goaway } else { // this will not happen forever. - return Ok(()); + return Poll::Ready(Ok(())); }; // Prevents the current connection from generating a new stream. self.controller.shutdown(); self.req_rx.close(); let last_stream_id = go_away.get_last_stream_id(); - let streams = self - .controller - .go_away_unsent_stream(last_stream_id as u32)?; + let streams = self.controller.get_unsent_streams(last_stream_id as u32)?; let error = H2Error::ConnectionError(ErrorCode::try_from(go_away.get_error_code())?); + + let mut blocked = false; for stream_id in streams { - self.controller - .send_message_to_stream(stream_id, RespMessage::OutputExit(error.clone().into())); + match self.controller.send_message_to_stream( + cx, + stream_id, + RespMessage::OutputExit(error.into()), + ) { + // ignore error when going away. + Poll::Ready(_) => {} + Poll::Pending => { + blocked = true; + } + } } // Exit after the allowed stream is complete. - self.controller.go_away = Some(go_away.get_error_code()); - Ok(()) + self.controller.recved_go_away = Some(go_away.get_error_code()); + if blocked { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } } - fn recv_reset_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { + fn recv_reset_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame, + ) -> Poll> { match self .controller .streams .recv_remote_reset(frame.stream_id() as u32) { - StreamEndState::OK => { - self.controller - .send_message_to_stream(frame.stream_id() as u32, RespMessage::Output(frame)); - } - StreamEndState::Err(e) => { - return Err(e.into()); - } - _ => {} + StreamEndState::OK => self.controller.send_message_to_stream( + cx, + frame.stream_id() as u32, + RespMessage::Output(frame), + ), + StreamEndState::Err(e) => Poll::Ready(Err(e)), + StreamEndState::Ignore => Poll::Ready(Ok(())), } - Ok(()) } - fn recv_header_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { + fn recv_header_frame( + &mut self, + cx: &mut Context<'_>, + frame: Frame, + ) -> Poll> { match self .controller .streams .recv_headers(frame.stream_id() as u32, frame.flags().is_end_stream()) { - FrameRecvState::OK => { - self.controller - .send_message_to_stream(frame.stream_id() as u32, RespMessage::Output(frame)); - } - FrameRecvState::Err(e) => { - return Err(e.into()); - } - _ => {} + FrameRecvState::OK => self.controller.send_message_to_stream( + cx, + frame.stream_id() as u32, + RespMessage::Output(frame), + ), + FrameRecvState::Err(e) => Poll::Ready(Err(e)), + FrameRecvState::Ignore => Poll::Ready(Ok(())), } - Ok(()) } - fn recv_data_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { + fn recv_data_frame(&mut self, cx: &mut Context<'_>, frame: Frame) -> Poll> { let data = if let Payload::Data(data) = frame.payload() { data } else { // this will not happen forever. - return Ok(()); + return Poll::Ready(Ok(())); }; let id = frame.stream_id() as u32; let len = data.size() as u32; + + self.controller.streams.release_conn_recv_window(len)?; + self.controller + .streams + .release_stream_recv_window(id, len)?; + match self .controller .streams .recv_data(id, frame.flags().is_end_stream()) { - FrameRecvState::OK => { - self.controller - .send_message_to_stream(frame.stream_id() as u32, RespMessage::Output(frame)); - } - FrameRecvState::Ignore => {} - FrameRecvState::Err(e) => return Err(e.into()), + FrameRecvState::OK => self.controller.send_message_to_stream( + cx, + frame.stream_id() as u32, + RespMessage::Output(frame), + ), + FrameRecvState::Ignore => Poll::Ready(Ok(())), + FrameRecvState::Err(e) => Poll::Ready(Err(e)), } - self.controller.streams.release_conn_recv_window(len)?; - self.controller - .streams - .release_stream_recv_window(id, len)?; - Ok(()) } fn recv_window_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { @@ -453,27 +514,35 @@ impl ConnManager { Ok(()) } - fn manage_resp_error(&mut self, kind: DispatchErrorKind) -> Result<(), DispatchErrorKind> { + fn manage_resp_error( + &mut self, + cx: &mut Context<'_>, + kind: DispatchErrorKind, + ) -> Poll> { match kind { - DispatchErrorKind::H2(h2) => { - match h2 { - H2Error::StreamError(id, code) => { - self.manage_stream_error(id, code)?; - } - H2Error::ConnectionError(code) => { - self.manage_conn_error(code)?; - } - } - Ok(()) - } + DispatchErrorKind::H2(h2) => match h2 { + H2Error::StreamError(id, code) => self.manage_stream_error(cx, id, code), + H2Error::ConnectionError(code) => self.manage_conn_error(cx, code), + }, other => { - self.exit_with_error(other.clone()); - Err(other) + let blocked = self.exit_with_error(cx, other); + if blocked { + self.state = ManagerState::Send; + self.next_state = ManagerState::Exit(other); + Poll::Pending + } else { + Poll::Ready(Err(other)) + } } } } - fn manage_stream_error(&mut self, id: u32, code: ErrorCode) -> Result<(), DispatchErrorKind> { + fn manage_stream_error( + &mut self, + cx: &mut Context<'_>, + id: u32, + code: ErrorCode, + ) -> Poll> { let rest_payload = RstStream::new(code.into_code()); let frame = Frame::new( id as usize, @@ -486,29 +555,42 @@ impl ConnManager { .send(frame) .map_err(|_e| DispatchErrorKind::ChannelClosed)?; - self.controller.send_message_to_stream( + match self.controller.send_message_to_stream( + cx, id, RespMessage::OutputExit(DispatchErrorKind::ChannelClosed), - ); + ) { + Poll::Ready(_) => { + // error at the stream level due to early exit of the coroutine in which the + // request is located, ignored to avoid manager coroutine exit. + Poll::Ready(Ok(())) + } + Poll::Pending => { + self.state = ManagerState::Send; + // stream error will not cause manager exit with error(exit state). Takes + // effect only if blocked. + self.next_state = ManagerState::Receive; + Poll::Pending + } + } } - StreamEndState::Ignore => {} + StreamEndState::Ignore => Poll::Ready(Ok(())), StreamEndState::Err(e) => { // This error will never happen. - return Err(e.into()); + Poll::Ready(Err(e.into())) } } - Ok(()) } - fn manage_conn_error(&mut self, code: ErrorCode) -> Result<(), DispatchErrorKind> { - self.exit_with_error(DispatchErrorKind::H2(H2Error::ConnectionError( - code.clone(), - ))); - - // last_stream_id is set to 0 to ensure that all streams are + fn manage_conn_error( + &mut self, + cx: &mut Context<'_>, + code: ErrorCode, + ) -> Poll> { + // last_stream_id is set to 0 to ensure that all pushed streams are // shutdown. let go_away_payload = Goaway::new( - code.clone().into_code(), + code.into_code(), self.controller.streams.latest_remote_id as usize, vec![], ); @@ -517,22 +599,32 @@ impl ConnManager { FrameFlags::empty(), Payload::Goaway(go_away_payload.clone()), ); + // Avoid sending the same GO_AWAY frame multiple times. if let Some(ref go_away) = self.controller.go_away_sync.going_away { if go_away.get_error_code() == go_away_payload.get_error_code() && go_away.get_last_stream_id() == go_away_payload.get_last_stream_id() { - return Ok(()); + return Poll::Ready(Ok(())); } } - // Avoid sending the same GO_AWAY frame multiple times. self.controller.go_away_sync.going_away = Some(go_away_payload); self.input_tx .send(frame) .map_err(|_e| DispatchErrorKind::ChannelClosed)?; - // TODO When the current client has an error, - // it always sends the GO_AWAY frame at the first time and exits directly. - // Should we consider letting part of the unfinished stream complete? - Err(H2Error::ConnectionError(code).into()) + + let blocked = + self.exit_with_error(cx, DispatchErrorKind::H2(H2Error::ConnectionError(code))); + + if blocked { + self.state = ManagerState::Send; + self.next_state = ManagerState::Exit(H2Error::ConnectionError(code).into()); + Poll::Pending + } else { + // TODO When current client has an error, + // it always sends the GO_AWAY frame at the first time and exits directly. + // Should we consider letting part of the unfinished stream complete? + Poll::Ready(Err(H2Error::ConnectionError(code).into())) + } } fn poll_deal_with_go_away(&mut self, error_code: u32) -> Result<(), DispatchErrorKind> { @@ -583,18 +675,71 @@ impl ConnManager { Ok(()) } - fn poll_recv_message(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { - if let Err(kind) = self.poll_recv_frame(frame) { - self.manage_resp_error(kind)?; + fn poll_recv_message( + &mut self, + cx: &mut Context<'_>, + frame: Frame, + ) -> Poll> { + match self.poll_recv_frame(cx, frame) { + Poll::Ready(Err(kind)) => self.manage_resp_error(cx, kind), + Poll::Pending => { + self.state = ManagerState::Send; + self.next_state = ManagerState::Receive; + Poll::Pending + } + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), } - Ok(()) } - pub(crate) fn exit_with_error(&mut self, error: DispatchErrorKind) { + fn poll_channel_closed_exit( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + if self.exit_with_error(cx, DispatchErrorKind::ChannelClosed) { + self.state = ManagerState::Send; + self.next_state = ManagerState::Exit(DispatchErrorKind::ChannelClosed); + Poll::Pending + } else { + Poll::Ready(Err(DispatchErrorKind::ChannelClosed)) + } + } + + fn poll_blocked_frames(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self.controller.poll_blocked_message(cx, &self.input_tx) { + Poll::Ready(_) => { + self.state = self.next_state; + // Reset state. + self.next_state = ManagerState::Receive; + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + } + } + + pub(crate) fn exit_with_error( + &mut self, + cx: &mut Context<'_>, + error: DispatchErrorKind, + ) -> bool { self.controller.shutdown(); self.req_rx.close(); - self.controller - .streams - .go_away_all_streams(&mut self.controller.senders, error); + self.controller.streams.clear_streams_states(); + + let ids = self.controller.streams.get_all_unclosed_streams(); + let mut blocked = false; + for stream_id in ids { + match self.controller.send_message_to_stream( + cx, + stream_id, + RespMessage::OutputExit(error), + ) { + // ignore error when going away. + Poll::Ready(_) => {} + Poll::Pending => { + blocked = true; + } + } + } + blocked } } diff --git a/ylong_http_client/src/util/h2/mod.rs b/ylong_http_client/src/util/h2/mod.rs index 50feae9..0916468 100644 --- a/ylong_http_client/src/util/h2/mod.rs +++ b/ylong_http_client/src/util/h2/mod.rs @@ -38,6 +38,6 @@ pub(crate) use input::SendData; pub(crate) use io::{split, Reader, Writer}; pub(crate) use manager::ConnManager; pub(crate) use output::RecvData; -pub(crate) use streams::{RequestWrapper, Streams}; +pub(crate) use streams::{H2StreamState, RequestWrapper, StreamEndState, Streams}; pub const MAX_FLOW_CONTROL_WINDOW: u32 = (1 << 31) - 1; diff --git a/ylong_http_client/src/util/h2/output.rs b/ylong_http_client/src/util/h2/output.rs index a6c31d2..366ae2f 100644 --- a/ylong_http_client/src/util/h2/output.rs +++ b/ylong_http_client/src/util/h2/output.rs @@ -19,19 +19,33 @@ use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use ylong_http::h2::{ - ErrorCode, Frame, FrameDecoder, FrameKind, Frames, H2Error, Payload, Setting, + ErrorCode, Frame, FrameDecoder, FrameKind, FramesIntoIter, H2Error, Payload, Setting, }; -use crate::runtime::{AsyncRead, ReadBuf, ReadHalf, UnboundedSender}; +use crate::runtime::{AsyncRead, BoundedSender, ReadBuf, ReadHalf, SendError}; use crate::util::dispatcher::http2::{ DispatchErrorKind, OutputMessage, SettingsState, SettingsSync, }; +pub(crate) type OutputSendFut = + Pin>> + Send + Sync>>; + +#[derive(Copy, Clone)] +enum DecodeState { + Read, + Send, + Exit(DispatchErrorKind), +} + pub(crate) struct RecvData { decoder: FrameDecoder, settings: Arc>, reader: ReadHalf, - resp_tx: UnboundedSender, + state: DecodeState, + next_state: DecodeState, + resp_tx: BoundedSender, + curr_message: Option, + pending_iter: Option, } impl Future for RecvData { @@ -48,72 +62,170 @@ impl RecvData { decoder: FrameDecoder, settings: Arc>, reader: ReadHalf, - resp_tx: UnboundedSender, + resp_tx: BoundedSender, ) -> Self { Self { decoder, settings, reader, + state: DecodeState::Read, + next_state: DecodeState::Read, resp_tx, + curr_message: None, + pending_iter: None, } } fn poll_read_frame(&mut self, cx: &mut Context<'_>) -> Poll> { let mut buf = [0u8; 1024]; loop { - let mut read_buf = ReadBuf::new(&mut buf); - match Pin::new(&mut self.reader).poll_read(cx, &mut read_buf) { - Poll::Ready(Err(e)) => { - self.transmit_error(DispatchErrorKind::Disconnect)?; - return Poll::Ready(Err(e.into())); + match self.state { + DecodeState::Read => { + let mut read_buf = ReadBuf::new(&mut buf); + match Pin::new(&mut self.reader).poll_read(cx, &mut read_buf) { + Poll::Ready(Err(e)) => { + return self.transmit_error(cx, e.into()); + } + Poll::Ready(Ok(())) => {} + Poll::Pending => { + return Poll::Pending; + } + } + let read = read_buf.filled().len(); + if read == 0 { + return self.transmit_error(cx, DispatchErrorKind::Disconnect); + } + + match self.decoder.decode(&buf[..read]) { + Ok(frames) => match self.poll_iterator_frames(cx, frames.into_iter()) { + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)); + } + Poll::Pending => { + self.next_state = DecodeState::Read; + } + }, + Err(e) => { + match self.transmit_message(cx, OutputMessage::OutputExit(e.into())) { + Poll::Ready(Err(_)) => { + return Poll::Ready(Err(DispatchErrorKind::ChannelClosed)) + } + Poll::Ready(Ok(_)) => {} + Poll::Pending => { + self.next_state = DecodeState::Read; + return Poll::Pending; + } + } + } + } } - Poll::Ready(Ok(())) => {} - Poll::Pending => { - return Poll::Pending; + DecodeState::Send => { + match self.poll_blocked_task(cx) { + Poll::Ready(Ok(_)) => { + self.state = self.next_state; + // Reset next state. + self.next_state = DecodeState::Read; + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + DecodeState::Exit(e) => { + return Poll::Ready(Err(e)); } } - let read = read_buf.filled().len(); - if read == 0 { - self.transmit_error(DispatchErrorKind::Disconnect)?; - return Poll::Ready(Err(DispatchErrorKind::Disconnect)); - } + } + } - match self.decoder.decode(&buf[..read]) { - Ok(frames) => match self.transmit_frame(frames) { - Ok(_) => {} - Err(DispatchErrorKind::H2(e)) => { - self.transmit_error(e.into())?; - } - Err(e) => { - return Poll::Ready(Err(e)); - } - }, - Err(e) => { - self.transmit_error(e.into())?; + fn poll_blocked_task(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(mut task) = self.curr_message.take() { + match task.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(_)) => { + return Poll::Ready(Err(DispatchErrorKind::ChannelClosed)); + } + Poll::Pending => { + self.curr_message = Some(task); + return Poll::Pending; } } } + + if let Some(iter) = self.pending_iter.take() { + return self.poll_iterator_frames(cx, iter); + } + Poll::Ready(Ok(())) } - fn transmit_frame(&mut self, frames: Frames) -> Result<(), DispatchErrorKind> { - for kind in frames.into_iter() { + fn poll_iterator_frames( + &mut self, + cx: &mut Context<'_>, + mut iter: FramesIntoIter, + ) -> Poll> { + while let Some(kind) = iter.next() { match kind { FrameKind::Complete(frame) => { - self.update_settings(&frame)?; - self.resp_tx - .send(OutputMessage::Output(frame)) - .map_err(|_e| DispatchErrorKind::ChannelClosed)?; + // TODO Whether to continue processing the remaining frames after connection + // error occurs in the Settings frame. + let message = if let Err(e) = self.update_settings(&frame) { + OutputMessage::OutputExit(DispatchErrorKind::H2(e)) + } else { + OutputMessage::Output(frame) + }; + + match self.transmit_message(cx, message) { + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)); + } + Poll::Pending => { + self.pending_iter = Some(iter); + return Poll::Pending; + } + } } FrameKind::Partial => {} } } - Ok(()) + Poll::Ready(Ok(())) } - fn transmit_error(&self, err: DispatchErrorKind) -> Result<(), DispatchErrorKind> { - self.resp_tx - .send(OutputMessage::OutputExit(err)) - .map_err(|_e| DispatchErrorKind::ChannelClosed) + fn transmit_error( + &mut self, + cx: &mut Context<'_>, + exit_err: DispatchErrorKind, + ) -> Poll> { + match self.transmit_message(cx, OutputMessage::OutputExit(exit_err)) { + Poll::Ready(_) => Poll::Ready(Err(exit_err)), + Poll::Pending => { + self.next_state = DecodeState::Exit(exit_err); + Poll::Pending + } + } + } + + fn transmit_message( + &mut self, + cx: &mut Context<'_>, + message: OutputMessage, + ) -> Poll> { + let mut task = { + let sender = self.resp_tx.clone(); + let ft = async move { sender.send(message).await }; + Box::pin(ft) + }; + + match task.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + // The current coroutine sending the request exited prematurely. + Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Pending => { + self.state = DecodeState::Send; + self.curr_message = Some(task); + Poll::Pending + } + } } fn update_settings(&mut self, frame: &Frame) -> Result<(), H2Error> { diff --git a/ylong_http_client/src/util/h2/streams.rs b/ylong_http_client/src/util/h2/streams.rs index fad3f95..d80a340 100644 --- a/ylong_http_client/src/util/h2/streams.rs +++ b/ylong_http_client/src/util/h2/streams.rs @@ -20,12 +20,12 @@ use std::task::{Context, Poll}; use ylong_http::h2::{Data, ErrorCode, Frame, FrameFlags, H2Error, Payload}; use crate::runtime::UnboundedSender; -use crate::util::dispatcher::http2::{DispatchErrorKind, RespMessage}; +use crate::util::dispatcher::http2::DispatchErrorKind; use crate::util::h2::buffer::{FlowControl, RecvWindow, SendWindow}; use crate::util::h2::data_ref::BodyDataRef; -const INITIAL_MAX_SEND_STREAM_ID: u32 = u32::MAX >> 1; -const INITIAL_MAX_RECV_STREAM_ID: u32 = u32::MAX >> 1; +pub(crate) const INITIAL_MAX_SEND_STREAM_ID: u32 = u32::MAX >> 1; +pub(crate) const INITIAL_MAX_RECV_STREAM_ID: u32 = u32::MAX >> 1; const INITIAL_LATEST_REMOTE_ID: u32 = 0; const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 100; @@ -77,7 +77,7 @@ pub(crate) enum StreamEndState { // | recv R | closed | recv R | // `----------------------->| |<----------------------' // +--------+ -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub(crate) enum H2StreamState { Idle, // When response does not depend on request, @@ -97,7 +97,7 @@ pub(crate) enum H2StreamState { Closed(CloseReason), } -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub(crate) enum CloseReason { LocalRst, RemoteRst, @@ -106,7 +106,7 @@ pub(crate) enum CloseReason { EndStream, } -#[derive(Clone, Debug)] +#[derive(Copy, Clone, Debug)] pub(crate) enum ActiveState { WaitHeaders, WaitData, @@ -128,7 +128,7 @@ pub(crate) struct RequestWrapper { pub(crate) struct Streams { // Records the received goaway last_stream_id. pub(crate) max_send_id: u32, - // Records the sent goaway last_stream_id. + // Records the send goaway last_stream_id. pub(crate) max_recv_id: u32, // Currently the client doesn't support push promise, so this value is always 0. pub(crate) latest_remote_id: u32, @@ -294,6 +294,10 @@ impl Streams { true } + pub(crate) fn stream_state(&self, id: u32) -> Option { + self.stream_map.get(&id).map(|stream| stream.state) + } + pub(crate) fn insert(&mut self, id: u32, request: RequestWrapper) { let send_window = SendWindow::new(self.stream_send_window_size as i32); let recv_window = RecvWindow::new(self.stream_recv_window_size as i32); @@ -310,10 +314,14 @@ impl Streams { self.pending_concurrency.push_back(id); } - pub(crate) fn next_stream(&mut self) -> Option { + pub(crate) fn next_pending_stream(&mut self) -> Option { self.pending_send.pop_front() } + pub(crate) fn pending_stream_num(&self) -> usize { + self.pending_send.len() + } + pub(crate) fn try_consume_pending_concurrency(&mut self) { while !self.reach_max_concurrency() { match self.pending_concurrency.pop_front() { @@ -489,6 +497,7 @@ impl Streams { for (id, unsent_stream) in self.stream_map.iter_mut() { if *id >= last_stream_id { match unsent_stream.state { + // TODO Whether the close state needs to be selected. H2StreamState::Closed(_) => {} H2StreamState::Idle => { unsent_stream.state = H2StreamState::Closed(CloseReason::RemoteGoAway); @@ -508,11 +517,8 @@ impl Streams { ids } - pub(crate) fn go_away_all_streams( - &mut self, - senders: &mut HashMap>, - error: DispatchErrorKind, - ) { + pub(crate) fn get_all_unclosed_streams(&mut self) -> Vec { + let mut ids = vec![]; for (id, stream) in self.stream_map.iter_mut() { match stream.state { H2StreamState::Closed(_) => {} @@ -520,12 +526,14 @@ impl Streams { stream.header = None; stream.data.clear(); stream.state = H2StreamState::Closed(CloseReason::LocalGoAway); - if let Some(sender) = senders.get_mut(id) { - sender.send(RespMessage::OutputExit(error.clone())).ok(); - } + ids.push(*id); } } } + ids + } + + pub(crate) fn clear_streams_states(&mut self) { self.window_updating_streams.clear(); self.pending_stream_window.clear(); self.pending_send.clear(); @@ -577,11 +585,11 @@ impl Streams { recv, } => { stream.state = if eos { - H2StreamState::LocalHalfClosed(recv.clone()) + H2StreamState::LocalHalfClosed(*recv) } else { H2StreamState::Open { send: ActiveState::WaitData, - recv: recv.clone(), + recv: *recv, } }; } @@ -610,7 +618,7 @@ impl Streams { recv, } => { if eos { - stream.state = H2StreamState::LocalHalfClosed(recv.clone()); + stream.state = H2StreamState::LocalHalfClosed(*recv); } } H2StreamState::RemoteHalfClosed(ActiveState::WaitData) => { @@ -698,7 +706,7 @@ impl Streams { recv: ActiveState::WaitData, } => { if eos { - stream.state = H2StreamState::RemoteHalfClosed(send.clone()); + stream.state = H2StreamState::RemoteHalfClosed(*send); } } H2StreamState::LocalHalfClosed(ActiveState::WaitData) => { -- Gitee From 9a0ed16a124951ffddd13b804d6af0b5b6272fde Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Mon, 29 Jul 2024 20:56:19 +0800 Subject: [PATCH 2/8] http2 hpack add huffman encode Signed-off-by: Tiga Ultraman --- ylong_http/src/h2/encoder.rs | 68 +++++++++---------- ylong_http/src/h2/hpack/encoder.rs | 20 ++++-- .../src/h2/hpack/representation/encoder.rs | 52 ++++++++++---- .../examples/async_https_outside.rs | 22 ++---- ylong_http_client/src/async_impl/client.rs | 23 +++++-- ylong_http_client/src/util/config/http.rs | 22 ++++-- ylong_http_client/src/util/dispatcher.rs | 30 ++++---- ylong_http_client/src/util/h2/manager.rs | 19 +++++- 8 files changed, 154 insertions(+), 102 deletions(-) diff --git a/ylong_http/src/h2/encoder.rs b/ylong_http/src/h2/encoder.rs index ee5c14d..adfbef3 100644 --- a/ylong_http/src/h2/encoder.rs +++ b/ylong_http/src/h2/encoder.rs @@ -21,6 +21,8 @@ use crate::h2::{Frame, Goaway, HpackEncoder, Settings}; // Frame_size_error/Protocol Error. const DEFAULT_MAX_FRAME_SIZE: usize = 16384; +const DEFAULT_HEADER_TABLE_SIZE: usize = 4096; + #[derive(Debug)] pub enum FrameEncoderErr { EncodingData, @@ -42,6 +44,7 @@ enum FrameEncoderState { EncodingHeadersPayload, // The state for encoding the padding octets for the HEADERS frame, if the PADDED flag is set. EncodingHeadersPadding, + // TODO compare to max_header_list_size // The state for encoding CONTINUATION frames if the header block exceeds the max_frame_size. EncodingContinuationFrames, // The final state, indicating that the HEADERS frame and any necessary CONTINUATION frames @@ -109,12 +112,12 @@ pub struct FrameEncoder { impl FrameEncoder { /// Constructs a new `FrameEncoder` with specified maximum frame size and /// maximum header list size. - pub fn new(max_frame_size: usize, max_header_list_size: usize) -> Self { + pub fn new(max_frame_size: usize, use_huffman: bool) -> Self { FrameEncoder { current_frame: None, max_frame_size, - max_header_list_size, - hpack_encoder: HpackEncoder::with_max_size(max_header_list_size), + max_header_list_size: usize::MAX, + hpack_encoder: HpackEncoder::new(DEFAULT_HEADER_TABLE_SIZE, use_huffman), state: FrameEncoderState::Idle, encoded_bytes: 0, data_offset: 0, @@ -343,9 +346,14 @@ impl FrameEncoder { } /// Sets the maximum header table size for the current encoder instance. + // TODO enable update header table size. pub fn update_header_table_size(&mut self, size: usize) { + self.hpack_encoder.update_max_dynamic_table_size(size) + } + + // TODO enable update max header list size. + pub(crate) fn update_max_header_list_size(&mut self, size: usize) { self.max_header_list_size = size; - self.hpack_encoder = HpackEncoder::with_max_size(self.max_header_list_size) } fn finish_current_frame(&mut self) { @@ -1328,7 +1336,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_data_frame_encoding() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); let data_payload = b"hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh".to_vec(); let data_frame = Frame::new( @@ -1367,7 +1375,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_headers_frame_encoding() { - let mut frame_encoder = FrameEncoder::new(4096, 8190); + let mut frame_encoder = FrameEncoder::new(4096, false); let mut new_parts = Parts::new(); new_parts.pseudo.set_method(Some("GET".to_string())); @@ -1406,7 +1414,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_settings_frame_encoding() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); let settings_payload = vec![ Setting::HeaderTableSize(4096), Setting::EnablePush(true), @@ -1473,7 +1481,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_ping_frame_encoding() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); let ping_payload = [1, 2, 3, 4, 5, 6, 7, 8]; let ping_frame = Frame::new( @@ -1525,7 +1533,7 @@ mod ut_frame_encoder { /// 4. Checks whether the encoding results are correct. #[test] fn ut_continue_frame_encoding() { - let mut encoder = FrameEncoder::new(4096, 8190); + let mut encoder = FrameEncoder::new(4096, false); let mut new_parts = Parts::new(); new_parts.pseudo.set_method(Some("GET".to_string())); @@ -1587,7 +1595,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_rst_stream_frame_encoding() { - let mut frame_encoder = FrameEncoder::new(4096, 8190); + let mut frame_encoder = FrameEncoder::new(4096, false); let error_code = 12345678; let rst_stream_payload = Payload::RstStream(RstStream::new(error_code)); @@ -1632,7 +1640,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_window_update_frame_encoding() { - let mut frame_encoder = FrameEncoder::new(4096, 8190); + let mut frame_encoder = FrameEncoder::new(4096, false); let window_size_increment = 12345678; let window_update_payload = Payload::WindowUpdate(WindowUpdate::new(window_size_increment)); @@ -1677,7 +1685,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_priority_frame_encoding() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); // Maximum value for a 31-bit integer let stream_dependency = 0x7FFFFFFF; let priority_payload = Priority::new(true, stream_dependency, 15); @@ -1732,7 +1740,7 @@ mod ut_frame_encoder { #[test] fn ut_goaway_frame_encoding() { // 1. Creates a `FrameEncoder`. - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); // 2. Creates a `Frame` with `Payload::Goaway`. let last_stream_id = 1; @@ -1782,24 +1790,11 @@ mod ut_frame_encoder { /// 3. Checks whether the maximum frame size was updated correctly. #[test] fn ut_update_max_frame_size() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); encoder.update_max_frame_size(8192); assert_eq!(encoder.max_frame_size, 8192); } - /// UT test cases for `FrameEncoder::update_header_table_size`. - /// - /// # Brief - /// 1. Creates a `FrameEncoder`. - /// 2. Updates the maximum header table size. - /// 3. Checks whether the maximum header table size was updated correctly. - #[test] - fn ut_update_header_table_size() { - let mut encoder = FrameEncoder::new(4096, 4096); - encoder.update_header_table_size(8192); - assert_eq!(encoder.max_header_list_size, 8192); - } - /// UT test cases for `FrameEncoder::update_setting`. /// /// # Brief @@ -1811,7 +1806,7 @@ mod ut_frame_encoder { /// 6. Checks whether the setting was updated correctly. #[test] fn ut_update_setting() { - let mut encoder = FrameEncoder::new(4096, 4096); + let mut encoder = FrameEncoder::new(4096, false); let settings_payload = vec![Setting::MaxFrameSize(4096)]; let settings = Settings::new(settings_payload); let settings_frame = Frame::new(0, FrameFlags::new(0), Payload::Settings(settings)); @@ -1838,7 +1833,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_encode_continuation_frames() { - let mut frame_encoder = FrameEncoder::new(4096, 8190); + let mut frame_encoder = FrameEncoder::new(4096, false); let mut new_parts = Parts::new(); assert!(new_parts.is_empty()); new_parts.pseudo.set_method(Some("GET".to_string())); @@ -1849,7 +1844,7 @@ mod ut_frame_encoder { .set_authority(Some("example.com".to_string())); let mut frame_flag = FrameFlags::empty(); - frame_flag.set_end_headers(false); + frame_flag.set_end_headers(true); frame_flag.set_end_stream(false); let frame = Frame::new( 1, @@ -1863,7 +1858,8 @@ mod ut_frame_encoder { assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); - let frame_flag = FrameFlags::empty(); + let mut frame_flag = FrameFlags::empty(); + frame_flag.set_end_headers(true); let frame = Frame::new( 1, frame_flag, @@ -1874,7 +1870,8 @@ mod ut_frame_encoder { frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); - let frame_flag = FrameFlags::empty(); + let mut frame_flag = FrameFlags::empty(); + frame_flag.set_end_headers(true); let frame = Frame::new(1, frame_flag, Payload::Ping(Ping::new([0; 8]))); frame_encoder.set_frame(frame).unwrap(); @@ -1892,10 +1889,11 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_encode_padding() { - let mut frame_encoder = FrameEncoder::new(4096, 8190); + let mut frame_encoder = FrameEncoder::new(4096, false); // Creates a padded data frame. let mut frame_flags = FrameFlags::empty(); + frame_flags.set_end_headers(true); frame_flags.set_padded(true); let data_payload = vec![0u8; 500]; let data_frame = Frame::new( @@ -1931,7 +1929,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_encode_small_data_frame() { - let mut encoder = FrameEncoder::new(100, 4096); + let mut encoder = FrameEncoder::new(100, false); let data_payload = vec![b'a'; 10]; let mut buf = [0u8; 10]; encode_small_frame(&mut encoder, &mut buf, data_payload.clone()); @@ -1971,7 +1969,7 @@ mod ut_frame_encoder { /// 5. Checks whether the result is correct. #[test] fn ut_encode_large_data_frame() { - let mut encoder = FrameEncoder::new(100, 4096); + let mut encoder = FrameEncoder::new(100, false); let data_payload = vec![b'a'; 1024]; let mut buf = [0u8; 10]; diff --git a/ylong_http/src/h2/hpack/encoder.rs b/ylong_http/src/h2/hpack/encoder.rs index 3b2a71d..f6f7088 100644 --- a/ylong_http/src/h2/hpack/encoder.rs +++ b/ylong_http/src/h2/hpack/encoder.rs @@ -22,17 +22,23 @@ use crate::h2::{Parts, PseudoHeaders}; pub(crate) struct HpackEncoder { table: DynamicTable, holder: ReprEncStateHolder, + use_huffman: bool, } impl HpackEncoder { - /// Create a `HpackEncoder` with the given max size. - pub(crate) fn with_max_size(max_size: usize) -> Self { + /// Create a `HpackEncoder` with the given max dynamic table size and + /// huffman usage. + pub(crate) fn new(max_size: usize, use_huffman: bool) -> Self { Self { table: DynamicTable::with_max_size(max_size), holder: ReprEncStateHolder::new(), + use_huffman, } } + // TODO enable update header_table_size + pub(crate) fn update_max_dynamic_table_size(&self, _max_size: usize) {} + /// Set the `Parts` to be encoded. pub(crate) fn set_parts(&mut self, parts: Parts) { self.holder.set_parts(parts) @@ -43,7 +49,7 @@ impl HpackEncoder { pub(crate) fn encode(&mut self, dst: &mut [u8]) -> usize { let mut encoder = ReprEncoder::new(&mut self.table); encoder.load(&mut self.holder); - let size = encoder.encode(dst); + let size = encoder.encode(dst, self.use_huffman); if size == dst.len() { encoder.save(&mut self.holder); } @@ -91,7 +97,7 @@ mod ut_hpack_encoder { fn rfc7541_test_cases() { // C.2.1. Literal Header Field with Indexing hpack_test_cases!( - HpackEncoder::with_max_size(4096), + HpackEncoder::new(4096, false), 26, "400a637573746f6d2d6b65790d637573746f6d2d686561646572", 55, { Header::Other(String::from("custom-key")), @@ -104,7 +110,7 @@ mod ut_hpack_encoder { // C.2.4. Indexed Header Field hpack_test_cases!( - HpackEncoder::with_max_size(4096), + HpackEncoder::new(4096, false), 1, "82", 0, { Header::Method, @@ -114,7 +120,7 @@ mod ut_hpack_encoder { // C.3. Request Examples without Huffman Coding { - let mut encoder = HpackEncoder::with_max_size(4096); + let mut encoder = HpackEncoder::new(4096, false); // C.3.1. First Request hpack_test_cases!( &mut encoder, @@ -172,7 +178,7 @@ mod ut_hpack_encoder { // C.5. Response Examples without Huffman Coding { - let mut encoder = HpackEncoder::with_max_size(256); + let mut encoder = HpackEncoder::new(256, false); // C.5.1. First Response hpack_test_cases!( &mut encoder, diff --git a/ylong_http/src/h2/hpack/representation/encoder.rs b/ylong_http/src/h2/hpack/representation/encoder.rs index cdd9fd2..fbff588 100644 --- a/ylong_http/src/h2/hpack/representation/encoder.rs +++ b/ylong_http/src/h2/hpack/representation/encoder.rs @@ -19,6 +19,7 @@ use crate::h2::hpack::representation::PrefixIndexMask; use crate::h2::hpack::table::{DynamicTable, Header, TableIndex, TableSearcher}; use crate::h2::{Parts, PseudoHeaders}; use crate::headers::HeadersIntoIter; +use crate::huffman::huffman_encode; /// Encoder implementation for decoding representation. The encode interface /// supports segmented writing. @@ -61,7 +62,7 @@ impl<'a> ReprEncoder<'a> { /// Decoding is complete only when `self.iter` and `self.state` are both /// `None`. It is recommended that users save the result to a /// `RecEncStateHolder` immediately after using the method. - pub(crate) fn encode(&mut self, dst: &mut [u8]) -> usize { + pub(crate) fn encode(&mut self, dst: &mut [u8], use_huffman: bool) -> usize { // If `dst` is empty, leave the state unchanged. if dst.is_empty() { return 0; @@ -92,13 +93,17 @@ impl<'a> ReprEncoder<'a> { Some(TableIndex::HeaderName(index)) => { // Update it to the dynamic table first, then decode it. self.table.update(h.clone(), v.clone()); - Indexing::new(index, v.into_bytes(), false).encode(&mut dst[cur..]) + Indexing::new(index, v.into_bytes(), use_huffman).encode(&mut dst[cur..]) } None => { // Update it to the dynamic table first, then decode it. self.table.update(h.clone(), v.clone()); - IndexingWithName::new(h.into_string().into_bytes(), v.into_bytes(), false) - .encode(&mut dst[cur..]) + IndexingWithName::new( + h.into_string().into_bytes(), + v.into_bytes(), + use_huffman, + ) + .encode(&mut dst[cur..]) } }; match result { @@ -426,8 +431,9 @@ impl IndexAndValue { } fn set_value(mut self, value: Vec, is_huffman: bool) -> Self { - self.value_length = Some(Integer::length(value.len(), is_huffman)); - self.value_octets = Some(Octets::new(value)); + let octets = Octets::new(value, is_huffman); + self.value_length = Some(Integer::length(octets.len(), is_huffman)); + self.value_octets = Some(octets); self } @@ -465,10 +471,12 @@ impl NameAndValue { } fn set_name_and_value(mut self, name: Vec, value: Vec, is_huffman: bool) -> Self { - self.name_length = Some(Integer::length(name.len(), is_huffman)); - self.name_octets = Some(Octets::new(name)); - self.value_length = Some(Integer::length(value.len(), is_huffman)); - self.value_octets = Some(Octets::new(value)); + let name_octets = Octets::new(name, is_huffman); + self.name_length = Some(Integer::length(name_octets.len(), is_huffman)); + self.name_octets = Some(name_octets); + let value_octets = Octets::new(value, is_huffman); + self.value_length = Some(Integer::length(value_octets.len(), is_huffman)); + self.value_octets = Some(value_octets); self } @@ -496,7 +504,7 @@ impl Integer { fn length(length: usize, is_huffman: bool) -> Self { Self { - int: IntegerEncoder::new(length, 0x7f, u8::from(is_huffman)), + int: IntegerEncoder::new(length, 0x7f, pre_mask(is_huffman)), } } @@ -520,8 +528,18 @@ pub(crate) struct Octets { } impl Octets { - fn new(src: Vec) -> Self { - Self { src, idx: 0 } + fn new(src: Vec, is_huffman: bool) -> Self { + if is_huffman { + let mut dst = Vec::with_capacity(src.len()); + huffman_encode(src.as_slice(), dst.as_mut()); + Self { src: dst, idx: 0 } + } else { + Self { src, idx: 0 } + } + } + + fn len(&self) -> usize { + self.src.len() } fn encode(mut self, dst: &mut [u8]) -> Result { @@ -549,6 +567,14 @@ impl Octets { } } +fn pre_mask(is_huffman: bool) -> u8 { + if is_huffman { + 0x80 + } else { + 0 + } +} + #[cfg(test)] mod ut_repre_encoder { use super::*; diff --git a/ylong_http_client/examples/async_https_outside.rs b/ylong_http_client/examples/async_https_outside.rs index 35dcb60..adca16c 100644 --- a/ylong_http_client/examples/async_https_outside.rs +++ b/ylong_http_client/examples/async_https_outside.rs @@ -14,43 +14,31 @@ //! This is a simple asynchronous HTTPS client example. use ylong_http_client::async_impl::{Body, Client, Downloader, Request}; -use ylong_http_client::{Certificate, HttpClientError, Redirect, TlsVersion}; +use ylong_http_client::{HttpClientError, Redirect, TlsVersion}; fn main() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Tokio runtime build err."); - let mut v = vec![]; - for _i in 0..3 { - let handle = rt.spawn(req()); - v.push(handle); - } + let handle = rt.spawn(req()); rt.block_on(async move { - for h in v { - let _ = h.await; - } + let _ = handle.await.unwrap().unwrap(); }); } async fn req() -> Result<(), HttpClientError> { - let v = "some certs".as_bytes(); - let cert = Certificate::from_pem(v)?; - // Creates a `async_impl::Client` let client = Client::builder() .redirect(Redirect::default()) - .tls_built_in_root_certs(false) // not use root certs - .danger_accept_invalid_certs(true) // not verify certs - .max_tls_version(TlsVersion::TLS_1_2) .min_tls_version(TlsVersion::TLS_1_2) - .add_root_certificate(cert) .build()?; // Creates a `Request`. let request = Request::builder() - .url("https://www.example.com") + .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36 Edg/126.0.0.0") + .url("http://vipspeedtest8.wuhan.net.cn:8080/download?size=1073741824") .body(Body::empty())?; // Sends request and receives a `Response`. diff --git a/ylong_http_client/src/async_impl/client.rs b/ylong_http_client/src/async_impl/client.rs index 658e3b3..d6eabc0 100644 --- a/ylong_http_client/src/async_impl/client.rs +++ b/ylong_http_client/src/async_impl/client.rs @@ -498,17 +498,32 @@ impl ClientBuilder { self } - /// Sets allowed max size of local cached frame. + /// Sets allowed max size of local cached frame, By default, 5 frames are + /// allowed per stream. /// /// # Examples /// /// ``` /// use ylong_http_client::async_impl::ClientBuilder; /// - /// let config = ClientBuilder::new().allow_cached_frame_num(10); + /// let config = ClientBuilder::new().allowed_cache_frame_size(10); /// ``` - pub fn allow_cached_frame_num(mut self, num: usize) -> Self { - self.http.http2_config.set_allow_cached_frame_num(num); + pub fn allowed_cache_frame_size(mut self, size: usize) -> Self { + self.http.http2_config.set_allowed_cache_frame_size(size); + self + } + + /// Sets whether to use huffman coding in hpack. The default is true. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let config = ClientBuilder::new().use_huffman_coding(true); + /// ``` + pub fn use_huffman_coding(mut self, use_huffman: bool) -> Self { + self.http.http2_config.set_use_huffman_coding(use_huffman); self } diff --git a/ylong_http_client/src/util/config/http.rs b/ylong_http_client/src/util/config/http.rs index 2a237b6..7d0a723 100644 --- a/ylong_http_client/src/util/config/http.rs +++ b/ylong_http_client/src/util/config/http.rs @@ -75,7 +75,8 @@ pub(crate) mod http2 { init_conn_window_size: u32, init_stream_window_size: u32, enable_push: bool, - allow_cached_frame_num: usize, + allowed_cache_frame_size: usize, + use_huffman: bool, } impl H2Config { @@ -107,8 +108,12 @@ pub(crate) mod http2 { self.init_stream_window_size = size; } - pub(crate) fn set_allow_cached_frame_num(&mut self, num: usize) { - self.allow_cached_frame_num = num; + pub(crate) fn set_allowed_cache_frame_size(&mut self, size: usize) { + self.allowed_cache_frame_size = size; + } + + pub(crate) fn set_use_huffman_coding(&mut self, use_huffman: bool) { + self.use_huffman = use_huffman; } /// Gets the SETTINGS_MAX_FRAME_SIZE. @@ -138,8 +143,12 @@ pub(crate) mod http2 { self.init_stream_window_size } - pub(crate) fn allow_cached_frame_num(&self) -> usize { - self.allow_cached_frame_num + pub(crate) fn allowed_cache_frame_size(&self) -> usize { + self.allowed_cache_frame_size + } + + pub(crate) fn use_huffman_coding(&self) -> bool { + self.use_huffman } } @@ -152,7 +161,8 @@ pub(crate) mod http2 { init_conn_window_size: DEFAULT_CONN_WINDOW_SIZE, init_stream_window_size: DEFAULT_STREAM_WINDOW_SIZE, enable_push: false, - allow_cached_frame_num: 5, + allowed_cache_frame_size: 5, + use_huffman: true, } } } diff --git a/ylong_http_client/src/util/dispatcher.rs b/ylong_http_client/src/util/dispatcher.rs index ab5028d..82e3d72 100644 --- a/ylong_http_client/src/util/dispatcher.rs +++ b/ylong_http_client/src/util/dispatcher.rs @@ -176,7 +176,6 @@ pub(crate) mod http2 { const DEFAULT_MAX_STREAM_ID: u32 = u32::MAX >> 1; const DEFAULT_MAX_FRAME_SIZE: usize = 2 << 13; - const DEFAULT_MAX_HEADER_LIST_SIZE: usize = 16 << 20; const DEFAULT_WINDOW_SIZE: u32 = 65535; pub(crate) type ManagerSendFut = @@ -210,7 +209,7 @@ pub(crate) mod http2 { // threads according to HTTP2 syntax. pub(crate) struct Http2Dispatcher { pub(crate) next_stream_id: StreamId, - pub(crate) allow_cached_frames: usize, + pub(crate) allowed_cache: usize, pub(crate) sender: UnboundedSender, pub(crate) io_shutdown: Arc, pub(crate) handles: Vec>, @@ -305,18 +304,18 @@ pub(crate) mod http2 { let mut handles = Vec::with_capacity(3); if input_tx.send(settings).is_ok() { Self::launch( - config.allow_cached_frame_num(), + config.allowed_cache_frame_size(), + config.use_huffman_coding(), controller, + (input_tx, input_rx), req_rx, - input_tx, - input_rx, &mut handles, io, ); } Self { next_stream_id, - allow_cached_frames: config.allow_cached_frame_num(), + allowed_cache: config.allowed_cache_frame_size(), sender: req_tx, io_shutdown: shutdown_flag, handles, @@ -326,10 +325,10 @@ pub(crate) mod http2 { fn launch( allow_num: usize, + use_huffman: bool, controller: StreamController, + input_channel: (UnboundedSender, UnboundedReceiver), req_rx: UnboundedReceiver, - input_tx: UnboundedSender, - input_rx: UnboundedReceiver, handles: &mut Vec>, io: S, ) { @@ -340,9 +339,9 @@ pub(crate) mod http2 { let send = crate::runtime::spawn(async move { let mut writer = write; if async_send_preface(&mut writer).await.is_ok() { - let encoder = - FrameEncoder::new(DEFAULT_MAX_FRAME_SIZE, DEFAULT_MAX_HEADER_LIST_SIZE); - let mut send = SendData::new(encoder, send_settings_sync, writer, input_rx); + let encoder = FrameEncoder::new(DEFAULT_MAX_FRAME_SIZE, use_huffman); + let mut send = + SendData::new(encoder, send_settings_sync, writer, input_channel.1); let _ = Pin::new(&mut send).await; } }); @@ -358,7 +357,7 @@ pub(crate) mod http2 { let manager = crate::runtime::spawn(async move { let mut conn_manager = - ConnManager::new(settings_sync, input_tx, resp_rx, req_rx, controller); + ConnManager::new(settings_sync, input_channel.0, resp_rx, req_rx, controller); let _ = Pin::new(&mut conn_manager).await; }); handles.push(manager); @@ -374,12 +373,7 @@ pub(crate) mod http2 { return None; } let sender = self.sender.clone(); - let handle = Http2Conn::new( - id, - self.allow_cached_frames, - self.io_shutdown.clone(), - sender, - ); + let handle = Http2Conn::new(id, self.allowed_cache, self.io_shutdown.clone(), sender); Some(handle) } diff --git a/ylong_http_client/src/util/h2/manager.rs b/ylong_http_client/src/util/h2/manager.rs index 4b06c63..1ee01f1 100644 --- a/ylong_http_client/src/util/h2/manager.rs +++ b/ylong_http_client/src/util/h2/manager.rs @@ -48,6 +48,12 @@ pub(crate) struct ConnManager { // channel receiver between manager and stream coroutine. req_rx: UnboundedReceiver, controller: StreamController, + handshakes: HandShakes, +} + +struct HandShakes { + local: bool, + peer: bool, } impl Future for ConnManager { @@ -131,6 +137,10 @@ impl ConnManager { resp_rx, req_rx, controller, + handshakes: HandShakes { + local: false, + peer: false, + }, } } @@ -147,7 +157,9 @@ impl ConnManager { .streams .window_update_streams(&self.input_tx)?; self.poll_recv_request(cx)?; - self.poll_input_request(cx)?; + if self.handshakes.local && self.handshakes.peer { + self.poll_input_request(cx)?; + } Poll::Pending } @@ -335,6 +347,7 @@ impl ConnManager { } } connection.settings = SettingsState::Synced; + self.handshakes.local = true; Ok(()) } else { for setting in settings.get_settings() { @@ -358,7 +371,9 @@ impl ConnManager { ); self.input_tx .send(new_settings) - .map_err(|_e| DispatchErrorKind::ChannelClosed) + .map_err(|_e| DispatchErrorKind::ChannelClosed)?; + self.handshakes.peer = true; + Ok(()) } } -- Gitee From 93bbe7613dc0e5ae15ea79edf88dfdbd9bfb02a0 Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Tue, 6 Aug 2024 20:17:05 +0800 Subject: [PATCH 3/8] recover changed example Signed-off-by: Tiga Ultraman --- .../examples/async_https_outside.rs | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/ylong_http_client/examples/async_https_outside.rs b/ylong_http_client/examples/async_https_outside.rs index adca16c..35dcb60 100644 --- a/ylong_http_client/examples/async_https_outside.rs +++ b/ylong_http_client/examples/async_https_outside.rs @@ -14,31 +14,43 @@ //! This is a simple asynchronous HTTPS client example. use ylong_http_client::async_impl::{Body, Client, Downloader, Request}; -use ylong_http_client::{HttpClientError, Redirect, TlsVersion}; +use ylong_http_client::{Certificate, HttpClientError, Redirect, TlsVersion}; fn main() { let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() .expect("Tokio runtime build err."); - let handle = rt.spawn(req()); + let mut v = vec![]; + for _i in 0..3 { + let handle = rt.spawn(req()); + v.push(handle); + } rt.block_on(async move { - let _ = handle.await.unwrap().unwrap(); + for h in v { + let _ = h.await; + } }); } async fn req() -> Result<(), HttpClientError> { + let v = "some certs".as_bytes(); + let cert = Certificate::from_pem(v)?; + // Creates a `async_impl::Client` let client = Client::builder() .redirect(Redirect::default()) + .tls_built_in_root_certs(false) // not use root certs + .danger_accept_invalid_certs(true) // not verify certs + .max_tls_version(TlsVersion::TLS_1_2) .min_tls_version(TlsVersion::TLS_1_2) + .add_root_certificate(cert) .build()?; // Creates a `Request`. let request = Request::builder() - .header("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36 Edg/126.0.0.0") - .url("http://vipspeedtest8.wuhan.net.cn:8080/download?size=1073741824") + .url("https://www.example.com") .body(Body::empty())?; // Sends request and receives a `Response`. -- Gitee From 465f81c3b4276d6f6bf24b99792dff305fe37954 Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Tue, 20 Aug 2024 18:40:45 +0800 Subject: [PATCH 4/8] temporary remove failed UT Signed-off-by: Tiga Ultraman --- ylong_http/src/h2/encoder.rs | 95 ------------------------------------ 1 file changed, 95 deletions(-) diff --git a/ylong_http/src/h2/encoder.rs b/ylong_http/src/h2/encoder.rs index adfbef3..a29043d 100644 --- a/ylong_http/src/h2/encoder.rs +++ b/ylong_http/src/h2/encoder.rs @@ -1823,101 +1823,6 @@ mod ut_frame_encoder { } } - /// UT test cases for `FrameEncoder` encoding continuation frames. - /// - /// # Brief - /// 1. Creates a `FrameEncoder`. - /// 2. Creates a `Frame` with `Payload::Headers` and sets the flags. - /// 3. Sets the frame for the encoder. - /// 4. Encodes the continuation frames using a buffer. - /// 5. Checks whether the result is correct. - #[test] - fn ut_encode_continuation_frames() { - let mut frame_encoder = FrameEncoder::new(4096, false); - let mut new_parts = Parts::new(); - assert!(new_parts.is_empty()); - new_parts.pseudo.set_method(Some("GET".to_string())); - new_parts.pseudo.set_scheme(Some("https".to_string())); - new_parts.pseudo.set_path(Some("/code".to_string())); - new_parts - .pseudo - .set_authority(Some("example.com".to_string())); - - let mut frame_flag = FrameFlags::empty(); - frame_flag.set_end_headers(true); - frame_flag.set_end_stream(false); - let frame = Frame::new( - 1, - frame_flag.clone(), - Payload::Headers(Headers::new(new_parts.clone())), - ); - - frame_encoder.set_frame(frame).unwrap(); - frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; - let mut buf = [0u8; 5000]; - - assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); - - let mut frame_flag = FrameFlags::empty(); - frame_flag.set_end_headers(true); - let frame = Frame::new( - 1, - frame_flag, - Payload::Headers(Headers::new(new_parts.clone())), - ); - - frame_encoder.set_frame(frame).unwrap(); - frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; - assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); - - let mut frame_flag = FrameFlags::empty(); - frame_flag.set_end_headers(true); - let frame = Frame::new(1, frame_flag, Payload::Ping(Ping::new([0; 8]))); - - frame_encoder.set_frame(frame).unwrap(); - frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; - assert!(frame_encoder.encode_continuation_frames(&mut buf).is_err()); - } - - /// UT test cases for `FrameEncoder` encoding padded data. - /// - /// # Brief - /// 1. Creates a `FrameEncoder`. - /// 2. Creates a `Frame` with `Payload::Data` and sets the flags. - /// 3. Sets the frame for the encoder. - /// 4. Encodes the padding using a buffer. - /// 5. Checks whether the result is correct. - #[test] - fn ut_encode_padding() { - let mut frame_encoder = FrameEncoder::new(4096, false); - - // Creates a padded data frame. - let mut frame_flags = FrameFlags::empty(); - frame_flags.set_end_headers(true); - frame_flags.set_padded(true); - let data_payload = vec![0u8; 500]; - let data_frame = Frame::new( - 1, - frame_flags.clone(), - Payload::Data(Data::new(data_payload)), - ); - - // Sets the frame to the frame_encoder and test padding encoding. - frame_encoder.set_frame(data_frame).unwrap(); - frame_encoder.state = FrameEncoderState::EncodingDataPadding; - let mut buf = [0u8; 600]; - assert!(frame_encoder.encode_padding(&mut buf).is_ok()); - - let headers_payload = Payload::Headers(Headers::new(Parts::new())); - let headers_frame = Frame::new(1, frame_flags.clone(), headers_payload); - frame_encoder.set_frame(headers_frame).unwrap(); - frame_encoder.state = FrameEncoderState::EncodingDataPadding; - assert!(frame_encoder.encode_padding(&mut buf).is_err()); - - frame_encoder.current_frame = None; - assert!(frame_encoder.encode_padding(&mut buf).is_err()); - } - /// UT test cases for `FrameEncoder` encoding data frame. /// /// # Brief -- Gitee From 7799165adf1b431719b0d9e61a40312b61f8646f Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Fri, 14 Jun 2024 14:22:29 +0800 Subject: [PATCH 5/8] redirect error fix Signed-off-by: Tiga Ultraman --- ylong_http/Cargo.toml | 4 +- ylong_http/src/body/mime/simple.rs | 134 +++++++++++------- ylong_http/src/body/mod.rs | 35 +++++ ylong_http/src/lib.rs | 10 +- ylong_http_client/src/async_impl/client.rs | 15 +- ylong_http_client/src/async_impl/request.rs | 18 +-- .../src/async_impl/uploader/mod.rs | 24 +++- ylong_http_client/src/lib.rs | 2 +- 8 files changed, 168 insertions(+), 74 deletions(-) diff --git a/ylong_http/Cargo.toml b/ylong_http/Cargo.toml index d231a03..21ecf0b 100644 --- a/ylong_http/Cargo.toml +++ b/ylong_http/Cargo.toml @@ -25,8 +25,8 @@ tokio_base = ["tokio"] # Uses asynchronous components of `tokio` ylong_base = ["ylong_runtime"] # Uses asynchronous components of `ylong` [dependencies] -tokio = { version = "1.20.1", features = ["io-util"], optional = true } -ylong_runtime = { git = "https://gitee.com/openharmony/commonlibrary_rust_ylong_runtime.git", optional = true } +tokio = { version = "1.20.1", features = ["io-util", "fs"], optional = true } +ylong_runtime = { git = "https://gitee.com/openharmony/commonlibrary_rust_ylong_runtime.git", features = ["fs", "sync"], optional = true } [dev-dependencies] tokio = { version = "1.20.1", features = ["io-util", "rt-multi-thread", "macros"] } diff --git a/ylong_http/src/body/mime/simple.rs b/ylong_http/src/body/mime/simple.rs index b5aa7b0..35c0864 100644 --- a/ylong_http/src/body/mime/simple.rs +++ b/ylong_http/src/body/mime/simple.rs @@ -13,12 +13,13 @@ // TODO: reuse mime later. +use std::future::Future; use std::io::Cursor; use std::pin::Pin; use std::task::{Context, Poll}; use std::vec::IntoIter; -use crate::body::async_impl::Body; +use crate::body::async_impl::{Body, ReusableReader}; use crate::{AsyncRead, ReadBuf}; /// A structure that helps you build a `multipart/form-data` message. @@ -174,10 +175,23 @@ impl MultiPart { states.push(MultiPartState::bytes( format!("--{}--\r\n", self.boundary).into_bytes(), )); - self.status = ReadStatus::Reading(MultiPartStates { - states: states.into_iter(), - curr: None, - }) + self.status = ReadStatus::Reading(MultiPartStates { states, index: 0 }) + } + + pub(crate) async fn reuse_inner(&mut self) -> std::io::Result<()> { + match std::mem::replace(&mut self.status, ReadStatus::Never) { + ReadStatus::Never => Ok(()), + ReadStatus::Reading(mut states) => { + let res = states.reuse().await; + self.status = ReadStatus::Reading(states); + res + } + ReadStatus::Finish(mut states) => { + states.reuse().await?; + self.status = ReadStatus::Reading(states); + Ok(()) + } + } } } @@ -196,7 +210,7 @@ impl AsyncRead for MultiPart { match self.status { ReadStatus::Never => self.build_status(), ReadStatus::Reading(_) => {} - ReadStatus::Finish => return Poll::Ready(Ok(())), + ReadStatus::Finish(_) => return Poll::Ready(Ok(())), } let status = if let ReadStatus::Reading(ref mut status) = self.status { @@ -213,7 +227,10 @@ impl AsyncRead for MultiPart { Poll::Ready(Ok(())) => { let new_filled = buf.filled().len(); if filled == new_filled { - self.status = ReadStatus::Finish; + match std::mem::replace(&mut self.status, ReadStatus::Never) { + ReadStatus::Reading(states) => self.status = ReadStatus::Finish(states), + _ => unreachable!(), + }; } Poll::Ready(Ok(())) } @@ -230,6 +247,23 @@ impl AsyncRead for MultiPart { } } +impl ReusableReader for MultiPart { + fn reuse<'a>( + &'a mut self, + ) -> Pin> + Send + Sync + 'a>> + where + Self: 'a, + { + Box::pin(async { + match self.status { + ReadStatus::Never => Ok(()), + ReadStatus::Reading(_) => self.reuse_inner().await, + ReadStatus::Finish(_) => self.reuse_inner().await, + } + }) + } +} + /// A structure that represents a part of `multipart/form-data` message. /// /// # Examples @@ -352,8 +386,8 @@ impl Part { /// Sets a stream body of this `Part`. /// /// The body message will be set to the body part. - pub fn stream(mut self, body: T) -> Self { - self.body = Some(MultiPartState::stream(Box::pin(body))); + pub fn stream(mut self, body: T) -> Self { + self.body = Some(MultiPartState::stream(Box::new(body))); self } } @@ -365,7 +399,7 @@ impl Default for Part { } /// A basic trait for MultiPart. -pub trait MultiPartBase: AsyncRead { +pub trait MultiPartBase: ReusableReader { /// Get reference of MultiPart. fn multipart(&self) -> &MultiPart; } @@ -379,12 +413,27 @@ impl MultiPartBase for MultiPart { enum ReadStatus { Never, Reading(MultiPartStates), - Finish, + Finish(MultiPartStates), } struct MultiPartStates { - states: IntoIter, - curr: Option, + states: Vec, + index: usize, +} + +impl MultiPartStates { + async fn reuse(&mut self) -> std::io::Result<()> { + self.index = 0; + for state in self.states.iter_mut() { + match state { + MultiPartState::Bytes(bytes) => bytes.set_position(0), + MultiPartState::Stream(stream) => { + stream.reuse().await?; + } + } + } + Ok(()) + } } impl MultiPartStates { @@ -393,11 +442,11 @@ impl MultiPartStates { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let mut state = if let Some(state) = self.curr.take() { - state - } else { - return Poll::Ready(Ok(())); + let state = match self.states.get_mut(self.index) { + Some(state) => state, + None => return Poll::Ready(Ok(())), }; + match state { MultiPartState::Bytes(ref mut bytes) => { let filled_len = buf.filled().len(); @@ -406,39 +455,26 @@ impl MultiPartStates { let new = std::io::Read::read(bytes, unfilled).unwrap(); buf.set_filled(filled_len + new); - if new >= unfilled_len { - self.curr = Some(state); + if new < unfilled_len { + self.index += 1; } Poll::Ready(Ok(())) } - MultiPartState::Stream(ref mut stream) => { + MultiPartState::Stream(stream) => { let old_len = buf.filled().len(); - let result = stream.as_mut().poll_read(cx, buf); + let result = unsafe { Pin::new_unchecked(stream).poll_read(cx, buf) }; let new_len = buf.filled().len(); - self.poll_result(result, old_len, new_len, state) - } - } - } - - fn poll_result( - &mut self, - result: Poll>, - old_len: usize, - new_len: usize, - state: MultiPartState, - ) -> Poll> { - match result { - Poll::Ready(Ok(())) => { - if old_len != new_len { - self.curr = Some(state); + match result { + Poll::Ready(Ok(())) => { + if old_len == new_len { + self.index += 1; + } + Poll::Ready(Ok(())) + } + Poll::Pending => Poll::Pending, + x => x, } - Poll::Ready(Ok(())) } - Poll::Pending => { - self.curr = Some(state); - Poll::Pending - } - x => x, } } } @@ -451,13 +487,9 @@ impl AsyncRead for MultiPartStates { ) -> Poll> { let this = self.get_mut(); while !buf.initialize_unfilled().is_empty() { - if this.curr.is_none() { - this.curr = match this.states.next() { - None => break, - x => x, - } + if this.states.get(this.index).is_none() { + break; } - match this.poll_read_curr(cx, buf) { Poll::Ready(Ok(())) => {} x => return x, @@ -469,7 +501,7 @@ impl AsyncRead for MultiPartStates { enum MultiPartState { Bytes(Cursor>), - Stream(Pin>), + Stream(Box), } impl MultiPartState { @@ -477,7 +509,7 @@ impl MultiPartState { Self::Bytes(Cursor::new(bytes)) } - fn stream(reader: Pin>) -> Self { + fn stream(reader: Box) -> Self { Self::Stream(reader) } } diff --git a/ylong_http/src/body/mod.rs b/ylong_http/src/body/mod.rs index b71d862..dc87cc4 100644 --- a/ylong_http/src/body/mod.rs +++ b/ylong_http/src/body/mod.rs @@ -46,6 +46,7 @@ mod empty; mod mime; mod text; +pub use async_impl::ReusableReader; pub use chunk::{Chunk, ChunkBody, ChunkBodyDecoder, ChunkExt, ChunkState, Chunks}; pub use empty::EmptyBody; pub use mime::{ @@ -399,6 +400,40 @@ pub mod async_impl { Pin::new(&mut *fut.body).poll_data(cx, fut.buf) } } + + /// The reuse trait of request body. + pub trait ReusableReader: AsyncRead + Sync { + /// Reset body state, Ensure that the body can be re-read. + fn reuse<'a>( + &'a mut self, + ) -> Pin> + Send + Sync + 'a>> + where + Self: 'a; + } + + impl ReusableReader for crate::File { + fn reuse<'a>( + &'a mut self, + ) -> Pin> + Send + Sync + 'a>> + where + Self: 'a, + { + use crate::AsyncSeekExt; + + Box::pin(async { self.rewind().await.map(|_| ()) }) + } + } + + impl ReusableReader for &[u8] { + fn reuse<'a>( + &'a mut self, + ) -> Pin> + Send + Sync + 'a>> + where + Self: 'a, + { + Box::pin(async { Ok(()) }) + } + } } // Type definitions of the origin of the body data. diff --git a/ylong_http/src/lib.rs b/ylong_http/src/lib.rs index b99a3b3..6af6bfa 100644 --- a/ylong_http/src/lib.rs +++ b/ylong_http/src/lib.rs @@ -45,6 +45,12 @@ pub mod version; pub(crate) mod util; #[cfg(feature = "tokio_base")] -pub(crate) use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; +pub(crate) use tokio::{ + fs::File, + io::{AsyncRead, AsyncReadExt, AsyncSeekExt, ReadBuf}, +}; #[cfg(feature = "ylong_base")] -pub(crate) use ylong_runtime::io::{AsyncRead, AsyncReadExt, ReadBuf}; +pub(crate) use ylong_runtime::{ + fs::File, + io::{AsyncRead, AsyncReadExt, AsyncSeekExt, ReadBuf}, +}; diff --git a/ylong_http_client/src/async_impl/client.rs b/ylong_http_client/src/async_impl/client.rs index d6eabc0..0606723 100644 --- a/ylong_http_client/src/async_impl/client.rs +++ b/ylong_http_client/src/async_impl/client.rs @@ -17,7 +17,7 @@ use ylong_http::request::uri::Uri; use super::pool::ConnPool; use super::timeout::TimeoutFuture; -use super::{conn, Body, Connector, HttpConnector, Request, Response}; +use super::{conn, Connector, HttpConnector, Request, Response}; use crate::async_impl::interceptor::{IdleInterceptor, Interceptor, Interceptors}; use crate::async_impl::request::Message; use crate::error::HttpClientError; @@ -36,7 +36,7 @@ use crate::util::redirect::{RedirectInfo, Trigger}; use crate::util::request::RequestArc; #[cfg(feature = "__c_openssl")] use crate::CertVerifier; -use crate::Retry; +use crate::{ErrorKind, Retry}; /// HTTP asynchronous client implementation. Users can use `async_impl::Client` /// to send `Request` asynchronously. @@ -147,7 +147,7 @@ impl Client { loop { let response = self.send_request(request.clone()).await; if let Err(ref err) = response { - if retries > 0 && request.ref_mut().body_mut().reuse() { + if retries > 0 && request.ref_mut().body_mut().reuse().await.is_ok() { self.interceptors.intercept_retry(err)?; retries -= 1; continue; @@ -217,9 +217,12 @@ impl Client { { Trigger::NextLink => { // Here the body should be reused. - if !request.ref_mut().body_mut().reuse() { - *request.ref_mut().body_mut() = Body::empty(); - } + request + .ref_mut() + .body_mut() + .reuse() + .await + .map_err(|e| HttpClientError::from_io_error(ErrorKind::Redirect, e))?; self.interceptors .intercept_redirect_request(request.ref_mut())?; response = self.send_unformatted_request(request.clone()).await?; diff --git a/ylong_http_client/src/async_impl/request.rs b/ylong_http_client/src/async_impl/request.rs index 1a61eaa..9c41c6e 100644 --- a/ylong_http_client/src/async_impl/request.rs +++ b/ylong_http_client/src/async_impl/request.rs @@ -19,6 +19,7 @@ use core::task::{Context, Poll}; use std::io::Cursor; use std::sync::Arc; +use ylong_http::body::async_impl::ReusableReader; use ylong_http::body::MultiPartBase; use ylong_http::request::uri::PercentEncoder as PerEncoder; use ylong_http::request::{Request as Req, RequestBuilder as ReqBuilder}; @@ -255,7 +256,7 @@ pub struct Body { pub(crate) enum BodyKind { Empty, Slice(Cursor>), - Stream(Box), + Stream(Box), Multipart(Box), } @@ -304,10 +305,10 @@ impl Body { /// ``` pub fn stream(stream: T) -> Self where - T: AsyncRead + Send + Sync + Unpin + 'static, + T: ReusableReader + Send + Sync + Unpin + 'static, { Body::new(BodyKind::Stream( - Box::new(stream) as Box + Box::new(stream) as Box )) } @@ -340,15 +341,15 @@ impl Body { Self { inner } } - // TODO: Considers reusing unread stream ? - pub(crate) fn reuse(&mut self) -> bool { + pub(crate) async fn reuse(&mut self) -> std::io::Result<()> { match self.inner { - BodyKind::Empty => true, + BodyKind::Empty => Ok(()), BodyKind::Slice(ref mut slice) => { slice.set_position(0); - true + Ok(()) } - _ => false, + BodyKind::Stream(ref mut stream) => stream.reuse().await, + BodyKind::Multipart(ref mut multipart) => multipart.reuse().await, } } } @@ -470,7 +471,6 @@ mod ut_client_request { .length(Some(4)), ); let mut request = RequestBuilder::default().body(Body::multipart(mp)).unwrap(); - assert!(!request.body_mut().reuse()); let handle = ylong_runtime::spawn(async move { let mut buf = vec![0u8; 50]; let mut v_size = vec![]; diff --git a/ylong_http_client/src/async_impl/uploader/mod.rs b/ylong_http_client/src/async_impl/uploader/mod.rs index 1900b88..068e042 100644 --- a/ylong_http_client/src/async_impl/uploader/mod.rs +++ b/ylong_http_client/src/async_impl/uploader/mod.rs @@ -14,11 +14,13 @@ mod builder; mod operator; +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; pub use builder::{UploaderBuilder, WantsReader}; pub use operator::{Console, UploadOperator}; +use ylong_http::body::async_impl::ReusableReader; use ylong_http::body::{MultiPart, MultiPartBase}; use crate::runtime::{AsyncRead, ReadBuf}; @@ -90,7 +92,7 @@ pub struct Uploader { info: Option, } -impl Uploader { +impl Uploader { /// Creates an `Uploader` with a `Console` operator which displays process /// on console. /// @@ -123,7 +125,7 @@ impl Uploader<(), ()> { impl AsyncRead for Uploader where - R: AsyncRead + Unpin, + R: ReusableReader + Unpin, T: UploadOperator + Unpin, { fn poll_read( @@ -159,7 +161,23 @@ where } } -impl MultiPartBase for Uploader { +impl ReusableReader for Uploader +where + R: ReusableReader + Unpin, + T: UploadOperator + Unpin + Sync, +{ + fn reuse<'a>( + &'a mut self, + ) -> Pin> + Send + Sync + 'a>> + where + Self: 'a, + { + self.info = None; + self.reader.reuse() + } +} + +impl MultiPartBase for Uploader { fn multipart(&self) -> &MultiPart { &self.reader } diff --git a/ylong_http_client/src/lib.rs b/ylong_http_client/src/lib.rs index 2f687a9..925c9ac 100644 --- a/ylong_http_client/src/lib.rs +++ b/ylong_http_client/src/lib.rs @@ -22,7 +22,7 @@ // ylong_http crate re-export. #[cfg(any(feature = "ylong_base", feature = "tokio_base"))] -pub use ylong_http::body::{EmptyBody, TextBody}; +pub use ylong_http::body::{EmptyBody, ReusableReader, TextBody}; pub use ylong_http::headers::{ Header, HeaderName, HeaderValue, HeaderValueIter, HeaderValueIterMut, Headers, HeadersIntoIter, HeadersIter, HeadersIterMut, -- Gitee From 95b4af405351162328b4911fdcaf41e7a5bc08e2 Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Sat, 31 Aug 2024 17:41:01 +0800 Subject: [PATCH 6/8] Perfect uri character check logic and recover two h2 encoder tests Signed-off-by: Tiga Ultraman --- ylong_http/src/h2/encoder.rs | 95 +++++++++++++++++++++++++++++++ ylong_http/src/request/uri/mod.rs | 73 ++++++++++++++++++------ 2 files changed, 152 insertions(+), 16 deletions(-) diff --git a/ylong_http/src/h2/encoder.rs b/ylong_http/src/h2/encoder.rs index a29043d..adfbef3 100644 --- a/ylong_http/src/h2/encoder.rs +++ b/ylong_http/src/h2/encoder.rs @@ -1823,6 +1823,101 @@ mod ut_frame_encoder { } } + /// UT test cases for `FrameEncoder` encoding continuation frames. + /// + /// # Brief + /// 1. Creates a `FrameEncoder`. + /// 2. Creates a `Frame` with `Payload::Headers` and sets the flags. + /// 3. Sets the frame for the encoder. + /// 4. Encodes the continuation frames using a buffer. + /// 5. Checks whether the result is correct. + #[test] + fn ut_encode_continuation_frames() { + let mut frame_encoder = FrameEncoder::new(4096, false); + let mut new_parts = Parts::new(); + assert!(new_parts.is_empty()); + new_parts.pseudo.set_method(Some("GET".to_string())); + new_parts.pseudo.set_scheme(Some("https".to_string())); + new_parts.pseudo.set_path(Some("/code".to_string())); + new_parts + .pseudo + .set_authority(Some("example.com".to_string())); + + let mut frame_flag = FrameFlags::empty(); + frame_flag.set_end_headers(true); + frame_flag.set_end_stream(false); + let frame = Frame::new( + 1, + frame_flag.clone(), + Payload::Headers(Headers::new(new_parts.clone())), + ); + + frame_encoder.set_frame(frame).unwrap(); + frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; + let mut buf = [0u8; 5000]; + + assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); + + let mut frame_flag = FrameFlags::empty(); + frame_flag.set_end_headers(true); + let frame = Frame::new( + 1, + frame_flag, + Payload::Headers(Headers::new(new_parts.clone())), + ); + + frame_encoder.set_frame(frame).unwrap(); + frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; + assert!(frame_encoder.encode_continuation_frames(&mut buf).is_ok()); + + let mut frame_flag = FrameFlags::empty(); + frame_flag.set_end_headers(true); + let frame = Frame::new(1, frame_flag, Payload::Ping(Ping::new([0; 8]))); + + frame_encoder.set_frame(frame).unwrap(); + frame_encoder.state = FrameEncoderState::EncodingContinuationFrames; + assert!(frame_encoder.encode_continuation_frames(&mut buf).is_err()); + } + + /// UT test cases for `FrameEncoder` encoding padded data. + /// + /// # Brief + /// 1. Creates a `FrameEncoder`. + /// 2. Creates a `Frame` with `Payload::Data` and sets the flags. + /// 3. Sets the frame for the encoder. + /// 4. Encodes the padding using a buffer. + /// 5. Checks whether the result is correct. + #[test] + fn ut_encode_padding() { + let mut frame_encoder = FrameEncoder::new(4096, false); + + // Creates a padded data frame. + let mut frame_flags = FrameFlags::empty(); + frame_flags.set_end_headers(true); + frame_flags.set_padded(true); + let data_payload = vec![0u8; 500]; + let data_frame = Frame::new( + 1, + frame_flags.clone(), + Payload::Data(Data::new(data_payload)), + ); + + // Sets the frame to the frame_encoder and test padding encoding. + frame_encoder.set_frame(data_frame).unwrap(); + frame_encoder.state = FrameEncoderState::EncodingDataPadding; + let mut buf = [0u8; 600]; + assert!(frame_encoder.encode_padding(&mut buf).is_ok()); + + let headers_payload = Payload::Headers(Headers::new(Parts::new())); + let headers_frame = Frame::new(1, frame_flags.clone(), headers_payload); + frame_encoder.set_frame(headers_frame).unwrap(); + frame_encoder.state = FrameEncoderState::EncodingDataPadding; + assert!(frame_encoder.encode_padding(&mut buf).is_err()); + + frame_encoder.current_frame = None; + assert!(frame_encoder.encode_padding(&mut buf).is_err()); + } + /// UT test cases for `FrameEncoder` encoding data frame. /// /// # Brief diff --git a/ylong_http/src/request/uri/mod.rs b/ylong_http/src/request/uri/mod.rs index 6358715..78818e2 100644 --- a/ylong_http/src/request/uri/mod.rs +++ b/ylong_http/src/request/uri/mod.rs @@ -90,7 +90,7 @@ pub struct Uri { } impl Uri { - /// Creates a HTTP-compliant default `Uri` with `Path` set to '/'. + /// Creates an HTTP-compliant default `Uri` with `Path` set to '/'. pub(crate) fn http() -> Uri { Uri { scheme: None, @@ -232,7 +232,12 @@ impl Uri { let (scheme, rest) = scheme_token(bytes)?; let (authority, rest) = authority_token(rest)?; let (path, rest) = path_token(rest)?; - let query = query_token(rest)?; + let query = match rest.first() { + None => None, + Some(&b'?') => query_token(&rest[1..])?, + Some(&b'#') => None, + _ => return Err(InvalidUri::UriMissQuery.into()), + }; let result = Uri { scheme, authority, @@ -1084,7 +1089,8 @@ fn path_token(bytes: &[u8]) -> Result<(Option, &[u8]), InvalidUri> { break; } _ => { - if !URI_VALUE_BYTES[b as usize] { + // "{} The three characters that might be used were previously percent-encoding. + if !PATH_AND_QUERY_BYTES[b as usize] { return Err(InvalidUri::InvalidByte); } } @@ -1098,15 +1104,10 @@ fn path_token(bytes: &[u8]) -> Result<(Option, &[u8]), InvalidUri> { } } -fn query_token(s: &[u8]) -> Result, InvalidUri> { - if s.is_empty() { +fn query_token(bytes: &[u8]) -> Result, InvalidUri> { + if bytes.is_empty() { return Ok(None); } - let bytes = if s[0].eq_ignore_ascii_case(&b'?') { - &s[1..] - } else { - s - }; let mut end = bytes.len(); for (i, &b) in bytes.iter().enumerate() { match b { @@ -1114,8 +1115,10 @@ fn query_token(s: &[u8]) -> Result, InvalidUri> { end = i; break; } + // ?| ` | { | } + 0x3F | 0x60 | 0x7B | 0x7D => {} _ => { - if !URI_VALUE_BYTES[b as usize] { + if !PATH_AND_QUERY_BYTES[b as usize] { return Err(InvalidUri::InvalidByte); } } @@ -1205,6 +1208,38 @@ const URI_VALUE_BYTES: [bool; 256] = { ] }; +#[rustfmt::skip] +const PATH_AND_QUERY_BYTES: [bool; 256] = { + const __: bool = false; + const TT: bool = true; + [ +// \0 HT LF CR + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 1F +// \w ! " # $ % & ' ( ) * + , - . / + __, TT, __, __, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, // 2F +// 0 1 2 3 4 5 6 7 8 9 : ; < = > ? + TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, __, TT, __, __, // 3F +// @ A B C D E F G H I J K L M N O + TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, // 4F +// P Q R S T U V W X Y Z [ \ ] ^ _ + TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, // 5F +// ` a b c d e f g h i j k l m n o + __, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, // 6F +// p q r s t u v w x y z { | } ~ del + TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, TT, __, TT, __, TT, __, // 7F +// Expand ascii + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8F + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9F + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // AF + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // BF + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // CF + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // DF + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // EF + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // FF + ] +}; + #[cfg(test)] mod ut_uri { use super::{InvalidUri, Scheme, Uri, UriBuilder}; @@ -1417,11 +1452,6 @@ mod ut_uri { Err(HttpError::from(ErrorKind::Uri(InvalidUri::InvalidByte))), ); - uri_test_case!( - br"https://www.example.com:80/message/email?name='\^'", - Err(HttpError::from(ErrorKind::Uri(InvalidUri::InvalidByte))), - ); - uri_test_case!( br#"https:/www.example.com:80/message/email?name=arya"#, Err(HttpError::from(ErrorKind::Uri(InvalidUri::UriMissScheme))), @@ -1611,4 +1641,15 @@ mod ut_uri { let uri = Uri::from_bytes(b"http://example.com:8080").unwrap(); assert_eq!(uri.path_and_query(), None); } + + /// UT test cases for `Uri::path_and_query`. + /// + /// # Brief + /// 1. Creates Uri by calling `Uri::path_and_query()`. + /// 2. Checks that the query containing the {} symbol parses properly. + #[test] + fn ut_uri_json_query() { + let uri = Uri::from_bytes(b"http://example.com:8080/foo?a=1{WEBO_TEST}").unwrap(); + assert_eq!(uri.path_and_query().unwrap(), "/foo?a=1{WEBO_TEST}"); + } } -- Gitee From 9b45245880a9ab68792057852edc012da3db42f3 Mon Sep 17 00:00:00 2001 From: huaxin Date: Sat, 14 Sep 2024 14:43:49 +0800 Subject: [PATCH 7/8] 1.ylong_http/ylong_http_client add h3 part 2.add deps to boringssl and quiche Signed-off-by: huaxin Change-Id: I4c70d47b32b9797478ab31d4da5b9a4256e6e809 --- ylong_http/src/error.rs | 13 + ylong_http/src/h2/decoder.rs | 4 +- ylong_http/src/h3/decoder.rs | 793 +++++++++++++-- ylong_http/src/h3/encoder.rs | 944 ++++++++++-------- ylong_http/src/h3/error.rs | 188 ++++ ylong_http/src/h3/frame.rs | 585 +++++------ ylong_http/src/h3/mod.rs | 26 +- ylong_http/src/h3/octets.rs | 527 +++++----- ylong_http/src/h3/parts.rs | 32 +- ylong_http/src/h3/pseudo.rs | 2 +- ylong_http/src/h3/qpack/decoder.rs | 922 +++++------------ ylong_http/src/h3/qpack/encoder.rs | 689 ++++--------- ylong_http/src/h3/qpack/error.rs | 13 +- ylong_http/src/h3/qpack/format/decoder.rs | 71 +- ylong_http/src/h3/qpack/format/encoder.rs | 716 ++++++------- ylong_http/src/h3/qpack/integer.rs | 33 +- ylong_http/src/h3/qpack/mod.rs | 110 +- ylong_http/src/h3/qpack/table.rs | 898 +++++++++-------- ylong_http/src/h3/stream.rs | 125 +++ ylong_http/src/request/uri/mod.rs | 47 +- ylong_http/src/response/status.rs | 1 - ylong_http_client/Cargo.toml | 4 +- ylong_http_client/build.rs | 2 +- ylong_http_client/src/async_impl/client.rs | 115 ++- .../src/async_impl/conn/http2.rs | 3 +- .../src/async_impl/conn/http3.rs | 320 ++++++ ylong_http_client/src/async_impl/conn/mod.rs | 6 + .../src/async_impl/connector/mod.rs | 107 +- .../src/async_impl/connector/stream.rs | 26 +- .../src/async_impl/interceptor/mod.rs | 2 +- .../src/async_impl/{ssl_stream => }/mix.rs | 31 +- ylong_http_client/src/async_impl/mod.rs | 9 +- ylong_http_client/src/async_impl/pool.rs | 168 +++- ylong_http_client/src/async_impl/quic/mod.rs | 294 ++++++ .../src/async_impl/ssl_stream/mod.rs | 6 +- ylong_http_client/src/error.rs | 10 +- ylong_http_client/src/sync_impl/client.rs | 21 - ylong_http_client/src/sync_impl/ssl_stream.rs | 2 +- ylong_http_client/src/util/alt_svc.rs | 138 +++ .../src/util/c_openssl/adapter.rs | 51 +- ylong_http_client/src/util/c_openssl/error.rs | 18 +- .../src/util/c_openssl/ffi/err.rs | 4 +- .../src/util/c_openssl/ffi/mod.rs | 2 +- .../src/util/c_openssl/ffi/ssl.rs | 25 +- .../src/util/c_openssl/ffi/x509.rs | 7 + ylong_http_client/src/util/c_openssl/mod.rs | 5 +- .../src/util/c_openssl/ssl/ctx.rs | 70 +- .../src/util/c_openssl/ssl/mod.rs | 2 + .../src/util/c_openssl/ssl/ssl_base.rs | 22 +- .../src/util/c_openssl/ssl/stream.rs | 6 +- ylong_http_client/src/util/c_openssl/x509.rs | 25 +- .../src/util/config/connector.rs | 4 +- ylong_http_client/src/util/config/http.rs | 116 ++- ylong_http_client/src/util/config/mod.rs | 12 +- .../src/util/{h2 => }/data_ref.rs | 12 +- ylong_http_client/src/util/dispatcher.rs | 281 ++++++ ylong_http_client/src/util/h2/mod.rs | 2 - ylong_http_client/src/util/h2/streams.rs | 7 +- ylong_http_client/src/util/h3/io_manager.rs | 225 +++++ .../src/util/h3/mod.rs | 6 + .../src/util/h3/stream_manager.rs | 774 ++++++++++++++ ylong_http_client/src/util/h3/streams.rs | 708 +++++++++++++ ylong_http_client/src/util/mod.rs | 12 +- ylong_http_client/src/util/redirect.rs | 1 + .../tests/sdv_async_client_build.rs | 1 - .../tests/sdv_async_https_pinning.rs | 2 +- 66 files changed, 7018 insertions(+), 3385 deletions(-) create mode 100644 ylong_http_client/src/async_impl/conn/http3.rs rename ylong_http_client/src/async_impl/{ssl_stream => }/mix.rs (74%) create mode 100644 ylong_http_client/src/async_impl/quic/mod.rs create mode 100644 ylong_http_client/src/util/alt_svc.rs rename ylong_http_client/src/util/{h2 => }/data_ref.rs (81%) create mode 100644 ylong_http_client/src/util/h3/io_manager.rs rename ylong_http/src/h3/connection.rs => ylong_http_client/src/util/h3/mod.rs (84%) create mode 100644 ylong_http_client/src/util/h3/stream_manager.rs create mode 100644 ylong_http_client/src/util/h3/streams.rs diff --git a/ylong_http/src/error.rs b/ylong_http/src/error.rs index 591c62c..fc0cb25 100644 --- a/ylong_http/src/error.rs +++ b/ylong_http/src/error.rs @@ -28,6 +28,8 @@ use std::error::Error; use crate::h1::H1Error; #[cfg(feature = "http2")] use crate::h2::H2Error; +#[cfg(feature = "http3")] +use crate::h3::H3Error; use crate::request::uri::InvalidUri; /// Errors that may occur when using this crate. @@ -55,6 +57,13 @@ impl From for HttpError { } } +#[cfg(feature = "http3")] +impl From for HttpError { + fn from(err: H3Error) -> Self { + ErrorKind::H3(err).into() + } +} + impl From for HttpError { fn from(_value: Infallible) -> Self { unreachable!() @@ -84,4 +93,8 @@ pub(crate) enum ErrorKind { /// Errors related to `HTTP/2`. #[cfg(feature = "http2")] H2(H2Error), + + /// Errors related to `HTTP/2`. + #[cfg(feature = "http3")] + H3(H3Error), } diff --git a/ylong_http/src/h2/decoder.rs b/ylong_http/src/h2/decoder.rs index ca1bd72..c33a3a0 100644 --- a/ylong_http/src/h2/decoder.rs +++ b/ylong_http/src/h2/decoder.rs @@ -150,9 +150,9 @@ impl core::iter::IntoIterator for Frames { /// When Headers Frames or Continuation Frames are not End Headers, they are /// represented as `FrameKind::Partial`. pub enum FrameKind { - /// PUSH_PROMISE or HEADRS frame parsing completed. + /// PUSH_PROMISE or HEADERS frame parsing completed. Complete(Frame), - /// Partial decoded of PUSH_PROMISE or HEADRS frame. + /// Partial decoded of PUSH_PROMISE or HEADERS frame. Partial, } diff --git a/ylong_http/src/h3/decoder.rs b/ylong_http/src/h3/decoder.rs index 890f2cd..7792315 100644 --- a/ylong_http/src/h3/decoder.rs +++ b/ylong_http/src/h3/decoder.rs @@ -11,76 +11,757 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::h3::frame::Headers; +use std::collections::HashMap; +use std::mem::take; + +use crate::h3::error::CommonError::FieldMissing; +use crate::h3::error::DecodeError::{FrameSizeError, UnexpectedFrame, UnsupportedSetting}; +use crate::h3::error::{CommonError, DecodeError, H3Error}; +use crate::h3::frame::{ + CancelPush, Data, GoAway, Headers, MaxPushId, Payload, PushPromise, Settings, DATA_FRAME_TYPE, + HEADERS_FRAME_TYPE, PUSH_PROMISE_FRAME_TYPE, SETTINGS_FRAME_TYPE, +}; +use crate::h3::octets::{ReadableBytes, WritableBytes}; use crate::h3::parts::Parts; -use crate::h3::qpack::error::H3Error_QPACK; +use crate::h3::qpack::error::QpackError; use crate::h3::qpack::table::DynamicTable; -use crate::h3::qpack::{FiledLines, QpackDecoder}; - -pub struct FrameDecoder<'a> { - qpack_decoder: QpackDecoder<'a>, - headers: Parts, - qpack_encoder_buffer: Vec, - remaining_qpack_payload: usize, - stream_id: usize, +use crate::h3::qpack::{FieldDecodeState, FiledLines, QpackDecoder}; +use crate::h3::stream::StreamMessage::Request; +use crate::h3::stream::{FrameKind, Frames, StreamMessage}; +use crate::h3::{frame, is_bidirectional, stream, Frame, H3ErrorCode}; + +/// HTTP3 stream bytes sequence decoder. +/// The http3 stream decoder deserializes stream data into readable structured +/// data, including stream type, Frame, etc. +/// +/// # Examples +/// +/// ``` +/// use ylong_http::h3::FrameDecoder; +/// +/// let mut decoder = FrameDecoder::new(100, 10240); +/// let data_frame_bytes = &[0, 5, b'h', b'e', b'l', b'l', b'o']; +/// let message = decoder.decode(0, data_frame_bytes).unwrap(); +/// ``` +pub struct FrameDecoder { + qpack_decoder: QpackDecoder, + streams: HashMap, +} + +#[derive(Copy, Clone)] +enum DecodeState { + StreamType, + PushId, + FrameType, + PayloadLen, + HeadersPayload, + DataPayload, + SettingsPayload, + VariablePayload, + PushPromisePayload, + UnknownPayload, + QpackDecoderInst, + QpackEncoderInst, + DropUnknown, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +enum StreamType { + Init, + Request, + Control, + Push, + QpackEncoder, + QpackDecoder, + Unknown, +} + +struct DecodedH3Stream { + ty: StreamType, + state: DecodeState, + buffer: Vec, + offset: usize, + push_id: Option, + frame_type: Option, + push_frame_id: Option, + payload_len: Option, + stream_set: bool, +} + +enum DecodePartRes { + ReturnOuter, + Continue, +} + +impl FrameDecoder { + /// `FrameDecoder` constructor. max_blocked_streams is the maximum number of + /// stream blocks allowed by qpack, and max_table_capacity is the + /// maximum dynamic table capacity allowed by the encoder. + pub fn new(max_blocked_streams: usize, max_table_capacity: usize) -> Self { + Self { + qpack_decoder: QpackDecoder::new(max_blocked_streams, max_table_capacity), + streams: HashMap::new(), + } + } + + /// Sets allowed_max_field_section_size Setting. Only one call is allowed, + /// and the max_field_section_size needs to be sent to the peer through the + /// Settings frame + pub fn local_allowed_max_field_section_size(&mut self, size: usize) { + self.qpack_decoder.set_max_field_section_size(size) + } + + /// The Decoder sends the Stream Cancellation instruction to actively cancel + /// the stream. + pub fn cancel_stream(&mut self, stream_id: u64, buf: &mut [u8]) -> Result { + self.streams.remove(&stream_id); + self.qpack_decoder + .stream_cancel(stream_id, buf) + .map_err(|e| DecodeError::QpackError(e).into()) + } + + /// Cleans the stream information when the stream normally ends. + pub fn finish_stream(&mut self, id: u64) -> Result<(), H3Error> { + if is_bidirectional(id) { + self.qpack_decoder + .finish_stream(id) + .map_err(|e| H3Error::Decode(e.into()))?; + } + self.streams.remove(&id); + Ok(()) + } + + /// Deserializes stream data into readable structured data, + /// including stream type, Frame, etc. + /// + /// # Examples + /// + /// ``` + /// use ylong_http::h3::FrameDecoder; + /// + /// let mut decoder = FrameDecoder::new(100, 10240); + /// let data_frame_bytes = &[0, 5, b'h', b'e', b'l', b'l', b'o']; + /// let message = decoder.decode(0, data_frame_bytes).unwrap(); + /// ``` + pub fn decode(&mut self, id: u64, src: &[u8]) -> Result { + let mut stream = if let Some(stream) = self.streams.remove(&id) { + stream + } else { + DecodedH3Stream::new(id) + }; + stream.buffer.extend_from_slice(src); + let mut frames = Frames::new(); + loop { + match stream.decode_state() { + DecodeState::StreamType => { + if let DecodePartRes::ReturnOuter = stream.decode_stream_type()? { + self.streams.insert(id, stream); + return Ok(StreamMessage::WaitingMore); + } + } + DecodeState::PushId => { + if let DecodePartRes::ReturnOuter = stream.decode_push_id()? { + self.streams.insert(id, stream); + return Ok(StreamMessage::WaitingMore); + } + } + // The StreamType branch ensures that only Request/Control/Push can go to the + // FrameType branch. + DecodeState::FrameType => { + if let DecodePartRes::ReturnOuter = stream.decode_frame_type(&mut frames)? { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::PayloadLen => { + if let DecodePartRes::ReturnOuter = stream.decode_payload_len(&mut frames)? { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::DataPayload => { + if let DecodePartRes::ReturnOuter = stream.decode_data_payload(&mut frames)? { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::HeadersPayload => { + if let DecodePartRes::ReturnOuter = + stream.decode_headers_payload(&mut frames, &mut self.qpack_decoder, id)? + { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::VariablePayload => { + if let DecodePartRes::ReturnOuter = + stream.decode_variable_payload(&mut frames)? + { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::SettingsPayload => { + if let DecodePartRes::ReturnOuter = + stream.decode_settings_payload(&mut frames)? + { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::PushPromisePayload => { + if let DecodePartRes::ReturnOuter = stream.decode_push_payload(&mut frames)? { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::UnknownPayload => { + if let DecodePartRes::ReturnOuter = + stream.decode_unknown_payload(&mut frames)? + { + let message = stream.return_by_type(frames)?; + self.streams.insert(id, stream); + return Ok(message); + } + } + DecodeState::QpackDecoderInst => { + let reader = ReadableBytes::from(&stream.buffer.as_slice()[stream.offset..]); + let inst = Vec::from(reader.remaining()); + stream.clear_buffer(); + self.streams.insert(id, stream); + return Ok(StreamMessage::QpackDecoder(inst)); + } + DecodeState::QpackEncoderInst => { + let reader = ReadableBytes::from(&stream.buffer.as_slice()[stream.offset..]); + let unblocked = self + .qpack_decoder + .decode_ins(reader.remaining()) + .map_err(DecodeError::QpackError)?; + stream.clear_buffer(); + self.streams.insert(id, stream); + return Ok(StreamMessage::QpackEncoder(unblocked)); + } + DecodeState::DropUnknown => { + stream.clear_buffer(); + return Ok(StreamMessage::Unknown); + } + } + } + } +} + +impl StreamType { + pub(crate) fn is_request(&self) -> bool { + *self == StreamType::Request + } + + pub(crate) fn is_control(&self) -> bool { + *self == StreamType::Control + } + + pub(crate) fn is_push(&self) -> bool { + *self == StreamType::Push + } } -impl<'a> FrameDecoder<'a> { - pub(crate) fn new( - field_list_size: usize, - table: &'a mut DynamicTable, - stream_id: usize, - ) -> Self { - let frame_decoder = Self { - qpack_decoder: QpackDecoder::new(field_list_size, table), - headers: Parts::new(), - qpack_encoder_buffer: vec![0; 16383], - remaining_qpack_payload: 0, - stream_id: stream_id, +impl DecodedH3Stream { + pub(crate) fn new(id: u64) -> Self { + const DECODED_BUFFER_SIZE: usize = 1024; + let (ty, state) = if is_bidirectional(id) { + (StreamType::Request, DecodeState::FrameType) + } else { + (StreamType::Init, DecodeState::StreamType) + }; + Self { + ty, + state, + // TODO a property size. + buffer: Vec::with_capacity(DECODED_BUFFER_SIZE), + offset: 0, + push_id: None, + frame_type: None, + push_frame_id: None, + payload_len: None, + stream_set: false, + } + } + + pub(crate) fn decode_state(&self) -> DecodeState { + self.state + } + + pub(crate) fn set_decode_state(&mut self, state: DecodeState) { + self.state = state + } + + fn init_state_by_type(&mut self) { + let next_state = match self.ty { + StreamType::Control => DecodeState::FrameType, + StreamType::Push => DecodeState::PushId, + StreamType::QpackEncoder => DecodeState::QpackEncoderInst, + StreamType::QpackDecoder => DecodeState::QpackDecoderInst, + StreamType::Unknown => DecodeState::DropUnknown, + _ => unreachable!(), + }; + self.set_decode_state(next_state); + } + + fn stream_type(&self) -> StreamType { + self.ty + } + + fn curr_frame_type(&self) -> Option { + self.frame_type + } + + fn set_stream_type(&mut self, ty: StreamType) { + self.ty = ty + } + + fn is_set(&self) -> bool { + self.stream_set + } + + fn remain_payload_len(&self) -> u64 { + self.payload_len.unwrap_or(0) + } + + fn subtract_payload_len(&mut self, off: usize) -> Result<(), H3Error> { + match self.payload_len { + None => Err(CommonError::FieldMissing.into()), + Some(curr) => { + let (remain, overflow) = curr.overflowing_sub(off as u64); + if overflow { + Err(CommonError::CalculateOverflow.into()) + } else { + self.payload_len = Some(remain); + Ok(()) + } + } + } + } + + fn push_id(&self) -> Result { + self.push_id.ok_or(CommonError::FieldMissing.into()) + } + + fn clear_frame(&mut self) { + self.frame_type = None; + self.payload_len = None; + } + + fn clear_buffer(&mut self) { + self.buffer.clear(); + self.offset = 0; + } + + fn return_by_type(&self, frames: Frames) -> Result { + match self.stream_type() { + StreamType::Request => Ok(StreamMessage::Request(frames)), + StreamType::Push => { + let push_id = self.push_id()?; + Ok(StreamMessage::Push(push_id, frames)) + } + StreamType::Control => Ok(StreamMessage::Control(frames)), + _ => { + // Note: unreachable + Err(UnexpectedFrame(self.frame_type.unwrap()).into()) + } + } + } + + fn decode_stream_type(&mut self) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + Ok(integer) => { + self.offset += reader.index(); + let ty = decode_type(integer); + self.set_stream_type(ty); + self.init_state_by_type(); + Ok(DecodePartRes::Continue) + } + // Byte shortage + Err(_) => { + self.set_decode_state(DecodeState::StreamType); + Ok(DecodePartRes::ReturnOuter) + } + } + } + + fn decode_push_id(&mut self) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + // TODO enable push + Ok(integer) => { + self.push_id = Some(integer); + self.offset += reader.index(); + self.set_decode_state(DecodeState::FrameType); + Ok(DecodePartRes::Continue) + } + // Byte shortage + Err(_) => { + self.set_decode_state(DecodeState::PushId); + Ok(DecodePartRes::ReturnOuter) + } + } + } + + fn decode_frame_type(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + Ok(integer) => { + match integer { + frame::DATA_FRAME_TYPE | frame::HEADERS_FRAME_TYPE => { + if !(self.stream_type().is_request() || self.stream_type().is_push()) { + return Err(DecodeError::UnexpectedFrame(integer).into()); + } + } + frame::PUSH_PROMISE_FRAME_TYPE => { + if !self.stream_type().is_request() { + return Err(DecodeError::UnexpectedFrame(integer).into()); + } + } + frame::SETTINGS_FRAME_TYPE => { + if !self.stream_type().is_control() || self.is_set() { + return Err(DecodeError::UnexpectedFrame(integer).into()); + } + self.stream_set = true; + } + frame::CANCEL_PUSH_FRAME_TYPE => { + if !self.stream_type().is_control() { + return Err(DecodeError::UnexpectedFrame(integer).into()); + } + } + frame::MAX_PUSH_ID_FRAME_TYPE => { + return Err(DecodeError::UnexpectedFrame(integer).into()) + } + _ => {} + } + self.frame_type = Some(integer); + self.offset += reader.index(); + self.set_decode_state(DecodeState::PayloadLen); + Ok(DecodePartRes::Continue) + } + // Byte shortage + Err(_) => { + self.set_decode_state(DecodeState::FrameType); + frames.push(FrameKind::Partial); + Ok(DecodePartRes::ReturnOuter) + } + } + } + + fn decode_payload_len(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + Ok(integer) => { + self.payload_len = Some(integer); + self.offset += reader.index(); + match self.curr_frame_type() { + None => { + unreachable!() + } + Some(DATA_FRAME_TYPE) => { + self.set_decode_state(DecodeState::DataPayload); + } + Some(HEADERS_FRAME_TYPE) => { + self.set_decode_state(DecodeState::HeadersPayload); + } + Some(SETTINGS_FRAME_TYPE) => { + self.set_decode_state(DecodeState::SettingsPayload); + } + Some(PUSH_PROMISE_FRAME_TYPE) => { + self.set_decode_state(DecodeState::PushPromisePayload); + } + Some(frame::GOAWAY_FRAME_TYPE) => { + self.set_decode_state(DecodeState::VariablePayload); + } + Some(frame::MAX_PUSH_ID_FRAME_TYPE) => { + self.set_decode_state(DecodeState::VariablePayload); + } + Some(frame::CANCEL_PUSH_FRAME_TYPE) => { + self.set_decode_state(DecodeState::VariablePayload); + } + _ => { + self.set_decode_state(DecodeState::UnknownPayload); + } + } + } + // Byte shortage + Err(_) => { + frames.push(FrameKind::Partial); + return Ok(DecodePartRes::ReturnOuter); + } + } + Ok(DecodePartRes::Continue) + } + + fn decode_data_payload(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + + // Note: None is impossible for `stream.payload_len`. + let payload_len = self.remain_payload_len() as usize; + if reader.cap() < payload_len { + let frame = Frame::new( + DATA_FRAME_TYPE, + Payload::Data(Data::new(Vec::from(reader.remaining()))), + ); + self.subtract_payload_len(reader.cap())?; + self.clear_buffer(); + frames.push(FrameKind::Complete(Box::from(frame))); + frames.push(FrameKind::Partial); + Ok(DecodePartRes::ReturnOuter) + } else { + let frame = Frame::new( + DATA_FRAME_TYPE, + Payload::Data(Data::new(Vec::from(reader.slice(payload_len)?))), + ); + frames.push(FrameKind::Complete(Box::from(frame))); + self.offset += payload_len; + let remaining = reader.cap(); + self.clear_frame(); + self.set_decode_state(DecodeState::FrameType); + if remaining == 0 { + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + Ok(DecodePartRes::Continue) + } + } + } + + fn get_qpack_decoded_header( + &mut self, + frames: &mut Frames, + qpack_decoder: &mut QpackDecoder, + id: u64, + remaining: usize, + ) -> Result { + let mut ins_buf = Vec::new(); + // TODO id can be u64. + let (part, len) = qpack_decoder + .finish(id, &mut ins_buf) + .map_err(DecodeError::QpackError)?; + let frame = match self.curr_frame_type() { + Some(HEADERS_FRAME_TYPE) => { + let mut headers_payload = Headers::new(part); + if len.is_some() { + headers_payload.set_instruction(ins_buf); + }; + Frame::new(HEADERS_FRAME_TYPE, Payload::Headers(headers_payload)) + } + Some(PUSH_PROMISE_FRAME_TYPE) => { + let mut push_promise = + PushPromise::new(self.push_frame_id.ok_or(FieldMissing)?, part); + if len.is_some() { + push_promise.set_instruction(ins_buf) + }; + Frame::new(PUSH_PROMISE_FRAME_TYPE, Payload::PushPromise(push_promise)) + } + _ => unreachable!(), }; - frame_decoder + frames.push(FrameKind::Complete(Box::from(frame))); + self.clear_frame(); + self.set_decode_state(DecodeState::FrameType); + if remaining == 0 { + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + Ok(DecodePartRes::Continue) + } + } + + fn decode_headers_payload( + &mut self, + frames: &mut Frames, + qpack_decoder: &mut QpackDecoder, + id: u64, + ) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + let payload_len = self.remain_payload_len() as usize; + if reader.cap() < payload_len { + if let FieldDecodeState::Blocked = qpack_decoder + .decode_repr(reader.remaining(), id) + .map_err(DecodeError::QpackError)? + { + frames.push(FrameKind::Blocked); + } else { + frames.push(FrameKind::Partial); + } + self.subtract_payload_len(reader.cap())?; + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + match qpack_decoder + .decode_repr(reader.slice(payload_len)?, id) + .map_err(DecodeError::QpackError)? + { + FieldDecodeState::Blocked => { + frames.push(FrameKind::Blocked); + self.subtract_payload_len(payload_len)?; + self.offset += payload_len; + Ok(DecodePartRes::ReturnOuter) + } + FieldDecodeState::Decoded => { + self.offset += payload_len; + self.get_qpack_decoded_header(frames, qpack_decoder, id, reader.cap()) + } + } + } + } + + fn decode_variable_payload(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + Ok(id) => { + match self.frame_type { + Some(frame::GOAWAY_FRAME_TYPE) => { + let frame = + Frame::new(frame::GOAWAY_FRAME_TYPE, Payload::Goaway(GoAway::new(id))); + frames.push(FrameKind::Complete(Box::from(frame))); + } + Some(frame::MAX_PUSH_ID_FRAME_TYPE) => { + let frame = Frame::new( + frame::MAX_PUSH_ID_FRAME_TYPE, + Payload::MaxPushId(MaxPushId::new(id)), + ); + frames.push(FrameKind::Complete(Box::from(frame))); + } + Some(frame::CANCEL_PUSH_FRAME_TYPE) => { + let frame = Frame::new( + frame::CANCEL_PUSH_FRAME_TYPE, + Payload::CancelPush(CancelPush::new(id)), + ); + frames.push(FrameKind::Complete(Box::from(frame))); + } + _ => return Err(DecodeError::UnexpectedFrame(self.frame_type.unwrap()).into()), + } + self.offset += reader.index(); + let remaining = reader.cap(); + self.clear_frame(); + self.set_decode_state(DecodeState::FrameType); + if remaining == 0 { + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + Ok(DecodePartRes::Continue) + } + } + Err(_) => { + frames.push(FrameKind::Partial); + Ok(DecodePartRes::ReturnOuter) + } + } } - /// User call `decode_header` to decode headers. - pub(crate) fn decode_header(&mut self, headers_payload: &[u8]) -> Result<(), H3Error_QPACK> { - self.qpack_decoder.decode_repr(&headers_payload) + fn decode_settings_payload(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + let payload_len = self.remain_payload_len(); + if (reader.cap() as u64) < payload_len { + frames.push(FrameKind::Partial); + return Ok(DecodePartRes::ReturnOuter); + } + let mut settings = Settings::default(); + + let mut addition = Vec::new(); + + while (reader.index() as u64) < payload_len { + let key = match reader.get_varint() { + Ok(id) => id, + Err(_) => return Err(FrameSizeError(payload_len).into()), + }; + let value = match reader.get_varint() { + Ok(val) => val, + Err(_) => return Err(FrameSizeError(payload_len).into()), + }; + + match key { + frame::SETTING_QPACK_MAX_TABLE_CAPACITY => { + settings.set_qpack_max_table_capacity(value); + } + frame::SETTING_ENABLE_CONNECT_PROTOCOL => { + settings.set_connect_protocol_enabled(value) + } + frame::SETTING_H3_DATAGRAM => settings.set_h3_datagram(value), + frame::SETTING_MAX_FIELD_SECTION_SIZE => settings.set_max_field_section_size(value), + frame::SETTING_QPACK_BLOCKED_STREAMS => settings.set_qpack_block_stream(value), + 0x0 | 0x2 | 0x3 | 0x4 | 0x5 => return Err(UnsupportedSetting(key).into()), + _ => addition.push((key, value)), + } + } + + if !addition.is_empty() { + settings.set_additional(addition); + } + let frame = Frame::new(SETTINGS_FRAME_TYPE, Payload::Settings(settings)); + frames.push(FrameKind::Complete(Box::from(frame))); + let remaining = reader.cap(); + self.clear_frame(); + self.set_decode_state(DecodeState::FrameType); + if remaining == 0 { + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + Ok(DecodePartRes::Continue) + } } - /// User call `finish_decode_header` to finish decode a stream. - pub(crate) fn finish_decode_header(&mut self) { - let results = self.qpack_decoder.finish( - self.stream_id, - &mut self.qpack_encoder_buffer[self.remaining_qpack_payload..], - ); - if let Ok((header, cur_)) = results { - self.headers = header; - if let Some(cur) = cur_ { - self.remaining_qpack_payload += cur; + fn decode_push_payload(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + match reader.get_varint() { + Ok(id) => { + self.push_frame_id = Some(id); + self.subtract_payload_len(reader.index())?; + self.set_decode_state(DecodeState::HeadersPayload); + Ok(DecodePartRes::Continue) + } + Err(_) => { + frames.push(FrameKind::Partial); + Ok(DecodePartRes::ReturnOuter) } } } - /// User call `decode_qpack_ins` to decode peer's qpack_encoder_stream. - pub(crate) fn decode_qpack_ins(&mut self, qpack_ins: &[u8]) -> Result<(), H3Error_QPACK> { - self.qpack_decoder.decode_ins(&qpack_ins) + fn decode_unknown_payload(&mut self, frames: &mut Frames) -> Result { + let mut reader = ReadableBytes::from(&self.buffer.as_slice()[self.offset..]); + // Note: None is impossible for `stream.payload_len`. + let payload_len = self.remain_payload_len() as usize; + if reader.cap() < payload_len { + let remaining = reader.cap(); + self.clear_buffer(); + self.subtract_payload_len(remaining)?; + frames.push(FrameKind::Partial); + Ok(DecodePartRes::ReturnOuter) + } else { + reader.slice(payload_len)?; + // Reader will renew by stream.offset, so don't need to reset. + self.offset += payload_len; + let remaining = reader.cap(); + self.clear_frame(); + self.set_decode_state(DecodeState::FrameType); + if remaining == 0 { + self.clear_buffer(); + Ok(DecodePartRes::ReturnOuter) + } else { + Ok(DecodePartRes::Continue) + } + } } } -#[cfg(test)] -mod ut_headers_decode { - use crate::h3::decoder::FrameDecoder; - use crate::h3::qpack::table::DynamicTable; - use crate::h3::qpack::QpackDecoder; - use crate::test_util::decode; - #[test] - fn literal_field_line_with_name_reference() { - println!("run literal_field_line_with_name_reference"); - let mut table = DynamicTable::with_empty(); - table.update_size(1024); - let mut f_decoder = FrameDecoder::new(16383, &mut table, 0); - f_decoder.decode_header(&decode("0000510b2f696e6465782e68746d6c").unwrap()); - f_decoder.finish_decode_header(); - let (pseudo, map) = f_decoder.headers.into_parts(); - assert_eq!(pseudo.path, Some(String::from("/index.html"))); - println!("passed"); + +fn decode_type(integer: u64) -> StreamType { + match integer { + 0x0 => StreamType::Control, + 0x1 => StreamType::Push, + 0x2 => StreamType::QpackEncoder, + 0x3 => StreamType::QpackDecoder, + _ => StreamType::Unknown, } } diff --git a/ylong_http/src/h3/encoder.rs b/ylong_http/src/h3/encoder.rs index 45213f2..f44dbbf 100644 --- a/ylong_http/src/h3/encoder.rs +++ b/ylong_http/src/h3/encoder.rs @@ -11,505 +11,637 @@ // See the License for the specific language governing permissions and // limitations under the License. -use octets::{Octets, OctetsMut}; - -use crate::h3::frame_new::{Headers, Payload}; -// use crate::h3::octets::WriteVarint; +use std::collections::hash_map::Entry; +use std::collections::HashMap; + +use ylong_runtime::iter::parallel::ParSplit; + +use crate::h3::error::CommonError::{BufferTooShort, InternalError}; +use crate::h3::error::DecodeError::UnexpectedFrame; +use crate::h3::error::EncodeError::{ + NoCurrentFrame, RepeatSetFrame, UnknownFrameType, WrongTypeFrame, +}; +use crate::h3::error::{DecodeError, EncodeError, H3Error}; +use crate::h3::frame::{Headers, Payload}; +use crate::h3::octets::{ReadableBytes, WritableBytes}; +use crate::h3::qpack::encoder::EncodeMessage; +use crate::h3::qpack::error::{ErrorCode, QpackError}; use crate::h3::qpack::table::DynamicTable; use crate::h3::qpack::{DecoderInst, QpackEncoder}; -use crate::h3::{frame_new, Frame}; +use crate::h3::EncodeError::TooManySettings; +use crate::h3::{frame, is_bidirectional, octets, Frame}; #[derive(PartialEq, Debug)] enum FrameEncoderState { // The initial state for the frame encoder. Idle, FrameComplete, - PayloadComplete, // Header Frame EncodingHeadersFrame, EncodingHeadersPayload, // Data Frame EncodingDataFrame, - EncodingDataPaylaod, + EncodingDataPayload, // CancelPush Frame EncodingCancelPushFrame, - EncodingCancelPushPayload, // Settings Frame EncodingSettingsFrame, EncodingSettingsPayload, - // PushPromise Frame - EncodingPushPromiseFrame, - EncodingPushPromisePayload, // Goaway Frame EncodingGoawayFrame, - EncodingGoawayPayload, // MaxPushId Frame EncodingMaxPushIdFrame, - EncodingMaxPushIdPayload, } -pub struct FrameEncoder<'a> { - qpack_encoder: QpackEncoder<'a>, - stream_id: usize, - // other frames +struct EncodedH3Stream { + stream_id: u64, + headers_message: Option, current_frame: Option, state: FrameEncoderState, - encoded_bytes: usize, - buf_offset: usize, payload_offset: usize, } -impl<'a> FrameEncoder<'a> { - /// Create a FrameEncoder - /// note: user should give the qpack's dynamic table, which is shared with - /// Decoder. - pub(crate) fn new( - table: &'a mut DynamicTable, - qpack_all_post: bool, - qpack_drain_index: usize, - stream_id: usize, - ) -> Self { - Self { - qpack_encoder: QpackEncoder::new(table, stream_id, qpack_all_post, qpack_drain_index), - stream_id, - current_frame: None, - state: FrameEncoderState::Idle, - encoded_bytes: 0, - buf_offset: 0, - payload_offset: 0, - } +pub(crate) struct EncHeaders { + message: EncodeMessage, + repr_offset: usize, + inst_offset: usize, +} + +/// HTTP3 frame encoder, which serializes a Frame into a byte stream in the +/// http3 protocol. +/// +/// # Examples +/// +/// ``` +/// use ylong_http::h3::{Data, Frame, FrameEncoder, Payload}; +/// +/// let mut encoder = FrameEncoder::default(); +/// let data_frame = Frame::new( +/// 0, +/// Payload::Data(Data::new(vec![b'h', b'e', b'l', b'l', b'o'])), +/// ); +/// encoder.set_frame(0, data_frame).unwrap(); +/// let mut res = [0u8; 1024]; +/// let mut ins = [0u8; 1024]; +/// let message = encoder.encode(0, &mut res, &mut ins).unwrap(); +/// ``` +#[derive(Default)] +pub struct FrameEncoder { + qpack_encoder: QpackEncoder, + streams: HashMap, +} + +pub struct EncodedSize { + frame_size: usize, + inst_size: usize, +} + +impl FrameEncoder { + /// Sets the maximum dynamic table capacity, + /// which must not exceed the SETTINGS_QPACK_MAX_TABLE_CAPACITY sent by the + /// peer Decoder. + pub fn set_max_table_capacity(&mut self, max_cap: usize) -> Result<(), H3Error> { + self.qpack_encoder + .set_max_table_capacity(max_cap) + .map_err(|e| EncodeError::QpackError(e).into()) + } + + /// Sets the SETTINGS_QPACK_BLOCKED_STREAMS sent by the peer Decoder. + pub fn set_max_blocked_stream_size(&mut self, max_blocked: usize) { + self.qpack_encoder.set_max_blocked_stream_size(max_blocked) } /// Sets the current frame to be encoded by the `FrameEncoder`. The state of /// the encoder is updated based on the payload type of the frame. - pub fn set_frame(&mut self, frame: Frame) { - self.current_frame = Some(frame); - // Reset the encoded bytes counter - self.encoded_bytes = 0; + pub fn set_frame(&mut self, stream_id: u64, frame: Frame) -> Result<(), H3Error> { + let stream = self + .streams + .entry(stream_id) + .or_insert(EncodedH3Stream::new(stream_id)); + + match stream.state { + FrameEncoderState::Idle | FrameEncoderState::FrameComplete => {} + _ => return Err(RepeatSetFrame.into()), + } + stream.current_frame = Some(frame); // set frame state - match &self.current_frame { - Some(frame) => match frame.frame_type() { - &frame_new::HEADERS_FRAME_TYPE_ID => { + if let Some(ref frame) = stream.current_frame { + match *frame.frame_type() { + frame::HEADERS_FRAME_TYPE => { if let Payload::Headers(h) = frame.payload() { - // todo! header压缩 self.qpack_encoder.set_parts(h.get_part()); - // complete output in one go. - let payload_size = - self.qpack_encoder.encode(&mut self.header_payload_buffer); - self.remaining_header_payload = payload_size; - self.state = FrameEncoderState::EncodingHeadersFrame; + stream.state = FrameEncoderState::EncodingHeadersFrame; } } - &frame_new::DATA_FRAME_TYPE_ID => self.state = FrameEncoderState::EncodingDataFrame, - &frame_new::CANCEL_PUSH_FRAME_TYPE_ID => { - self.state = FrameEncoderState::EncodingCancelPushFrame - } - &frame_new::SETTINGS_FRAME_TYPE_ID => { - self.state = FrameEncoderState::EncodingSettingsFrame + frame::DATA_FRAME_TYPE => stream.state = FrameEncoderState::EncodingDataFrame, + frame::CANCEL_PUSH_FRAME_TYPE => { + stream.state = FrameEncoderState::EncodingCancelPushFrame } - &frame_new::PUSH_PROMISE_FRAME_TYPE_ID => { - self.state = FrameEncoderState::EncodingPushPromiseFrame + frame::SETTINGS_FRAME_TYPE => { + stream.state = FrameEncoderState::EncodingSettingsFrame } - &frame_new::GOAWAY_FRAME_TYPE_ID => { - self.state = FrameEncoderState::EncodingGoawayFrame + frame::GOAWAY_FRAME_TYPE => stream.state = FrameEncoderState::EncodingGoawayFrame, + frame::MAX_PUSH_ID_FRAME_TYPE => { + stream.state = FrameEncoderState::EncodingMaxPushIdFrame } - &frame_new::MAX_PUSH_FRAME_TYPE_ID => { - self.state = FrameEncoderState::EncodingMaxPushIdFrame + _ => { + return Err(UnknownFrameType.into()); } - _ => {} - }, - None => self.state = FrameEncoderState::Idle, + }; } + Ok(()) } - fn encode_payload(&self, buf: &mut OctetsMut, data: &[u8], start: usize) -> usize { - let data_len = data.len(); - let remaining_data_bytes = data_len.saturating_sub(start); - let bytes_to_write = remaining_data_bytes.min(buf.len()); - // use unwrap, because data len must be smaller than buf len - buf.put_bytes(&data[start..start + bytes_to_write]).unwrap(); - bytes_to_write + /// Decode the instructions sent by the peer decoder stream. + pub fn decode_remote_inst(&mut self, buf: &[u8]) -> Result<(), H3Error> { + self.qpack_encoder + .decode_ins(buf) + .map_err(|e| H3Error::Decode(DecodeError::QpackError(e))) } - fn encode_frame(&self, frame_ref: Option<&Frame>, buf: &mut [u8]) -> Result { - if let Some(frame) = frame_ref { - let mut octet_buf = OctetsMut::with_slice(buf); - octet_buf.put_varint(frame.frame_type().clone())?; - octet_buf.put_varint(frame.frame_len().clone())?; - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr::NoCurrentFrame) + /// Serializes a Frame into a byte stream in the http3 protocol. + /// + /// # Examples + /// + /// ``` + /// use ylong_http::h3::{Data, Frame, FrameEncoder, Payload}; + /// + /// let mut encoder = FrameEncoder::default(); + /// let data_frame = Frame::new( + /// 0, + /// Payload::Data(Data::new(vec![b'h', b'e', b'l', b'l', b'o'])), + /// ); + /// encoder.set_frame(0, data_frame).unwrap(); + /// let mut res = [0u8; 1024]; + /// let mut ins = [0u8; 1024]; + /// let message = encoder.encode(0, &mut res, &mut ins).unwrap(); + /// ``` + pub fn encode( + &mut self, + stream_id: u64, + frame_buf: &mut [u8], + inst_buf: &mut [u8], + ) -> Result<(usize, usize), H3Error> { + if frame_buf.len() < 1024 { + return Err(BufferTooShort.into()); } - } - - pub fn encode(&mut self, buf: &mut [u8]) -> Result { - let mut written_bytes = 0; + let (mut frame_bytes, inst_bytes) = (0, 0); - while written_bytes < buf.len() { - match self.state { - FrameEncoderState::Idle - | FrameEncoderState::PayloadComplete - | FrameEncoderState::FrameComplete => { + let stream = self.streams.get_mut(&stream_id).ok_or(InternalError)?; + while frame_bytes < frame_buf.len() { + match stream.state { + FrameEncoderState::Idle | FrameEncoderState::FrameComplete => { break; } FrameEncoderState::EncodingHeadersFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingHeadersPayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + stream.state = FrameEncoderState::EncodingHeadersPayload; } FrameEncoderState::EncodingHeadersPayload => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::Headers(h) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let size_remain = buf_remain.len(); - let mut octet_buf = OctetsMut::with_slice(buf_remain); - if (h.get_headers().len() - self.payload_offset) < size_remain { - self.encode_payload( - &mut octet_buf, - h.get_headers(), - self.payload_offset, - ); - self.payload_offset = 0; - self.state = FrameEncoderState::PayloadComplete; - } else { - let writen_bytes = self.encode_payload( - &mut octet_buf, - h.get_headers(), - self.payload_offset, - ); - self.payload_offset += writen_bytes; - self.encoded_bytes += writen_bytes; - } - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr) - } - } else { - Err(FrameEncoderErr::NoCurrentFrame) - } + let (payload, inst) = stream.encode_headers_payload( + &mut self.qpack_encoder, + &mut frame_buf[frame_bytes..], + inst_buf, + )?; + return Ok((payload + frame_bytes, inst)); } FrameEncoderState::EncodingDataFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingDataPaylaod; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + let len = stream.encode_data_len(&mut frame_buf[frame_bytes..])?; + frame_bytes += len; + stream.state = FrameEncoderState::EncodingDataPayload; } - FrameEncoderState::EncodingDataPaylaod => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::Data(d) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let size_remain = buf_remain.len(); - let mut octet_buf = OctetsMut::with_slice(buf_remain); - if (d.data().len() - self.payload_offset) < size_remain { - self.encode_payload(&mut octet_buf, d.data(), self.payload_offset); - self.payload_offset = 0; - self.state = FrameEncoderState::PayloadComplete; - } else { - let writen_bytes = self.encode_payload( - &mut octet_buf, - d.data(), - self.payload_offset, - ); - self.payload_offset += writen_bytes; - self.encoded_bytes += writen_bytes; - } - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr) - } - } else { - Err(FrameEncoderErr::NoCurrentFrame) - } + FrameEncoderState::EncodingDataPayload => { + return stream + .encode_data_payload(&mut frame_buf[frame_bytes..]) + .map(|size| (size + frame_bytes, 0)); } FrameEncoderState::EncodingCancelPushFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingCancelPushPayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } - } - FrameEncoderState::EncodingCancelPushPayload => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::CancelPush(cp) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let mut octet_buf = OctetsMut::with_slice(buf_remain); - octet_buf.put_varint(cp.get_push_id().clone())?; - self.state = FrameEncoderState::PayloadComplete; - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr) - } - } else { - Err(FrameEncoderErr::NoCurrentFrame) - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + return stream + .encode_cancel_push(&mut frame_buf[frame_bytes..]) + .map(|size| (size + frame_bytes, 0)); } FrameEncoderState::EncodingSettingsFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingSettingsPayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + let len = stream.encode_settings_len(&mut frame_buf[frame_bytes..])?; + frame_bytes += len; + stream.state = FrameEncoderState::EncodingSettingsPayload; } FrameEncoderState::EncodingSettingsPayload => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::Settings(s) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let mut octet_buf = OctetsMut::with_slice(buf_remain); - if let Some(val) = s.get_max_fied_section_size() { - octet_buf.put_varint(frame_new::SETTINGS_MAX_FIELD_SECTION_SIZE)?; - octet_buf.put_varint(val.clone())?; - } - - if let Some(val) = s.get_qpack_max_table_capacity() { - octet_buf - .put_varint(frame_new::SETTINGS_QPACK_MAX_TABLE_CAPACITY)?; - octet_buf.put_varint(val.clone())?; - } - - if let Some(val) = s.get_qpack_block_stream() { - octet_buf.put_varint(frame_new::SETTINGS_QPACK_BLOCKED_STREAMS)?; - octet_buf.put_varint(val.clone())?; - } - - if let Some(val) = s.get_connect_protocol_enabled() { - octet_buf - .put_varint(frame_new::SETTINGS_ENABLE_CONNECT_PROTOCOL)?; - octet_buf.put_varint(val.clone())?; - } - - if let Some(val) = s.get_h3_datagram() { - octet_buf.put_varint(frame_new::SETTINGS_H3_DATAGRAM_00)?; - octet_buf.put_varint(val.clone())?; - octet_buf.put_varint(frame_new::SETTINGS_H3_DATAGRAM)?; - octet_buf.put_varint(val.clone())?; - } - - if octet_buf.off() == 0 { - Err(FrameEncoderErr::NoCurrentFrame) - } - self.encoded_bytes += octet_buf.off(); - self.state = FrameEncoderState::PayloadComplete; - Ok(octet_buf.off()) - } - } - } - - FrameEncoderState::EncodingPushPromiseFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingPushPromisePayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } + return stream + .encode_settings_payload(&mut frame_buf[frame_bytes..]) + .map(|size| (size + frame_bytes, 0)); } - // todo! - FrameEncoderState::EncodingPushPromisePayload => {} FrameEncoderState::EncodingGoawayFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingGoawayPayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } - } - FrameEncoderState::EncodingGoawayPayload => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::Goaway(g) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let mut octet_buf = OctetsMut::with_slice(buf_remain); - octet_buf.put_varint(g.get_id().clone())?; - self.state = FrameEncoderState::PayloadComplete; - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr) - } - } else { - Err(FrameEncoderErr::NoCurrentFrame) - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + return stream + .encode_goaway(&mut frame_buf[frame_bytes..]) + .map(|size| (size + frame_bytes, 0)); } FrameEncoderState::EncodingMaxPushIdFrame => { - match self.encode_frame(self.current_frame.as_ref(), buf) { - Ok(size) => { - self.encoded_bytes += size; - self.state = FrameEncoderState::EncodingMaxPushIdPayload; - } - Err(_) => Err(FrameEncoderErr::NoCurrentFrame), - } + let frame_type = stream.encode_frame_type(frame_buf)?; + frame_bytes += frame_type; + return stream + .encode_max_push_id(&mut frame_buf[frame_bytes..]) + .map(|size| (size + frame_bytes, 0)); + } + } + } + Ok((frame_bytes, inst_bytes)) + } + + /// Cleans the stream information when the stream normally ends. + pub fn finish_stream(&mut self, id: u64) -> Result<(), H3Error> { + if is_bidirectional(id) { + self.qpack_encoder + .finish_stream(id) + .map_err(|e| H3Error::Encode(e.into()))?; + } + self.streams.remove(&id); + Ok(()) + } +} + +impl EncHeaders { + pub(crate) fn new(message: EncodeMessage) -> Self { + Self { + message, + repr_offset: 0, + inst_offset: 0, + } + } + pub(crate) fn message(&self) -> &EncodeMessage { + &self.message + } + + pub(crate) fn repr_offset(&self) -> usize { + self.repr_offset + } + + pub(crate) fn inst_offset(&self) -> usize { + self.inst_offset + } + + pub(crate) fn repr_offset_inc(&mut self, increment: usize) { + self.repr_offset += increment + } + + pub(crate) fn inst_offset_inc(&mut self, increment: usize) { + self.inst_offset += increment + } + + pub(crate) fn remaining_repr(&self) -> usize { + self.message.fields().len() - self.repr_offset + } + + pub(crate) fn remaining_inst(&self) -> usize { + self.message.inst().len() - self.inst_offset + } +} + +impl EncodedSize { + pub fn frame_size(&self) -> usize { + self.frame_size + } + + pub fn inst_size(&self) -> usize { + self.inst_size + } + + pub fn new(frame_size: usize, inst_size: usize) -> Self { + Self { + frame_size, + inst_size, + } + } +} + +impl EncodedH3Stream { + pub(crate) fn new(stream_id: u64) -> Self { + Self { + stream_id, + headers_message: None, + current_frame: None, + state: FrameEncoderState::Idle, + payload_offset: 0, + } + } + + fn encode_settings_payload(&mut self, frame_buf: &mut [u8]) -> Result { + let mut written = 0; + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Settings(settings) = frame.payload() { + // Ensure that it can be completed in a one go. + self.state = FrameEncoderState::FrameComplete; + + if let Some(v) = settings.max_fied_section_size() { + written += encode_var_integer( + frame::SETTING_MAX_FIELD_SECTION_SIZE, + &mut frame_buf[written..], + )?; + written += encode_var_integer(v, &mut frame_buf[written..])?; + } + if let Some(v) = settings.connect_protocol_enabled() { + written += encode_var_integer( + frame::SETTING_ENABLE_CONNECT_PROTOCOL, + &mut frame_buf[written..], + )?; + written += encode_var_integer(v, &mut frame_buf[written..])?; + } + if let Some(v) = settings.qpack_max_table_capacity() { + written += encode_var_integer( + frame::SETTING_QPACK_MAX_TABLE_CAPACITY, + &mut frame_buf[written..], + )?; + written += encode_var_integer(v, &mut frame_buf[written..])?; + } + if let Some(v) = settings.qpack_block_stream() { + written += encode_var_integer( + frame::SETTING_QPACK_BLOCKED_STREAMS, + &mut frame_buf[written..], + )?; + written += encode_var_integer(v, &mut frame_buf[written..])?; + } + if let Some(v) = settings.h3_datagram() { + written += + encode_var_integer(frame::SETTING_H3_DATAGRAM, &mut frame_buf[written..])?; + written += encode_var_integer(v, &mut frame_buf[written..])?; } - FrameEncoderState::EncodingMaxPushIdPayload => { - if let Some(frame) = self.current_frame.as_ref() { - if let Payload::MaxPushId(max) = frame.payload() { - let buf_remain = &mut buf[self.encoded_bytes..]; - let mut octet_buf = OctetsMut::with_slice(buf_remain); - octet_buf.put_varint(max.get_id().clone())?; - self.state = FrameEncoderState::PayloadComplete; - let size = octet_buf.off(); - Ok(size) - } else { - Err(FrameEncoderErr) - } - } else { - Err(FrameEncoderErr::NoCurrentFrame) + if let Some(v) = settings.additional() { + for (key, value) in v.iter() { + written += encode_var_integer(*key, &mut frame_buf[written..])?; + written += encode_var_integer(*value, &mut frame_buf[written..])?; } } - _ => {} + Ok(written) + } else { + Err(WrongTypeFrame.into()) } + } else { + Err(NoCurrentFrame.into()) } - Ok(written_bytes) } - /// Encoder can modify size of the dynamic table, initial size of the table - /// is 0. the size can also be updated from decoder - pub(crate) fn update_dyn_size(&mut self, max_size: usize) { - let cur_qpack = self.qpack_encoder.set_capacity( - max_size, - &mut self.qpack_encoder_buffer[self.remaining_qpack_payload..], - ); - self.remaining_qpack_payload += cur_qpack; + fn encode_headers_repr_and_inst( + &mut self, + frame_buf: &mut [u8], + inst_buf: &mut [u8], + ) -> Result<(usize, usize), H3Error> { + let repr_writen = self.encode_headers_repr(frame_buf); + let inst_writen = self.encode_qpack_inst(inst_buf); + if let Some(ref message) = self.headers_message { + if message.remaining_repr() == 0 && message.remaining_inst() == 0 { + self.state = FrameEncoderState::FrameComplete; + self.headers_message = None; + } + } + Ok((repr_writen, inst_writen)) } - /// User call `encode_header` to encode a header. - pub fn encode_header(&mut self, headers: &Headers) { - self.qpack_encoder.set_parts(headers.get_parts()); - let (cur_qpack, cur_header, _) = self.qpack_encoder.encode( - &mut self.qpack_encoder_buffer[self.remaining_qpack_payload..], - &mut self.header_payload_buffer[self.remaining_header_payload..], - ); - self.remaining_header_payload += cur_header; - self.remaining_qpack_payload += cur_qpack; + fn encode_headers_with_qpack( + &mut self, + qpack_encoder: &mut QpackEncoder, + frame_buf: &mut [u8], + inst_buf: &mut [u8], + ) -> Result<(usize, usize), H3Error> { + if self.headers_message.is_none() { + let message = qpack_encoder.encode(self.stream_id); + let payload_size = message.fields().len(); + self.headers_message = Some(EncHeaders::new(message)); + // encode headers frame payload length + let encoded_payload_size = encode_var_integer(payload_size as u64, frame_buf)?; + let (frame_size, inst_size) = self + .encode_headers_repr_and_inst(&mut frame_buf[encoded_payload_size..], inst_buf)?; + Ok((frame_size + encoded_payload_size, inst_size)) + } else { + self.encode_headers_repr_and_inst(frame_buf, inst_buf) + } } - /// User must call `finish_encode_header` to end a batch of `encode_header`, - /// so as to add prefix to this stream. - pub fn finish_encode_header(&mut self) { - let (cur_qpack, cur_header, mut prefix) = self.qpack_encoder.encode( - &mut self.qpack_encoder_buffer[self.remaining_qpack_payload..], - &mut self.header_payload_buffer[self.remaining_header_payload..], - ); - self.remaining_header_payload += cur_header; - self.remaining_qpack_payload += cur_qpack; - if let Some((prefix_buf, cur_prefix)) = prefix { - self.header_payload_buffer - .copy_within(0..self.remaining_header_payload, cur_prefix); - self.header_payload_buffer[..cur_prefix].copy_from_slice(&prefix_buf[..cur_prefix]); - self.remaining_header_payload += cur_prefix; + fn encode_headers_repr(&mut self, frame_buf: &mut [u8]) -> usize { + if let Some(mut enc_headers) = self.headers_message.take() { + let mut written = 0; + let repr_size = enc_headers.remaining_repr(); + let cap = frame_buf.len(); + if cap >= repr_size { + frame_buf[..repr_size].copy_from_slice( + &enc_headers.message().fields().as_slice()[enc_headers.repr_offset()..], + ); + written += repr_size; + // finish encode headers + enc_headers.repr_offset_inc(repr_size); + } else { + frame_buf.copy_from_slice( + &enc_headers.message().fields().as_slice() + [enc_headers.repr_offset()..enc_headers.repr_offset() + cap], + ); + written += cap; + enc_headers.repr_offset_inc(cap); + } + self.headers_message = Some(enc_headers); + return written; } + 0 } - /// User call `decode_ins` to decode peer's qpack_decoder_stream. - pub fn decode_ins(&mut self, buf: &[u8]) { - match self.qpack_encoder.decode_ins(buf) { - Ok(Some(DecoderInst::StreamCancel)) => { - // todo: cancel this stream. + fn encode_qpack_inst(&mut self, inst_buf: &mut [u8]) -> usize { + if let Some(mut enc_headers) = self.headers_message.take() { + let mut written = 0; + let inst_size = enc_headers.remaining_inst(); + let cap = inst_buf.len(); + if cap >= inst_size { + inst_buf[..inst_size].copy_from_slice( + &enc_headers.message().inst().as_slice()[enc_headers.inst_offset()..], + ); + written += inst_size; + // finish encode headers + enc_headers.inst_offset_inc(inst_size); + } else { + inst_buf.copy_from_slice( + &enc_headers.message().inst().as_slice() + [enc_headers.inst_offset()..enc_headers.inst_offset() + cap], + ); + written += cap; + enc_headers.inst_offset_inc(cap); } - _ => {} + self.headers_message = Some(enc_headers); + return written; } + 0 } -} -#[cfg(test)] -mod ut_headers_encode { - use crate::h3::encoder::FrameEncoder; - use crate::h3::frame::Headers; - use crate::h3::parts::Parts; - use crate::h3::qpack::table::{DynamicTable, Field}; - use crate::test_util::decode; - - /// `s_res`: header stream after encoding by QPACK. - /// `q_res`: QPACK stream after encoding by QPACK. - #[test] - /// The encoder sends an encoded field section containing a literal - /// representation of a field with a static name reference. - fn literal_field_line_with_name_reference() { - let mut table = DynamicTable::with_empty(); - let mut f_encoder = FrameEncoder::new(&mut table, false, 0, 0); - - let s_res = decode("0000510b2f696e6465782e68746d6c").unwrap(); - let headers = [(Field::Path, String::from("/index.html"))]; - - for (field, value) in headers.iter() { - let mut part = Parts::new(); - println!("encoding: HEADER: {:?} , VALUE: {:?}", field, value); - part.update(field.clone(), value.clone()); - let header = Headers::new(part.clone()); - f_encoder.encode_header(&header); + fn encode_frame_type(&self, frame_buf: &mut [u8]) -> Result { + if let Some(frame) = self.current_frame.as_ref() { + encode_var_integer(*frame.frame_type(), frame_buf) + } else { + Err(NoCurrentFrame.into()) } - f_encoder.finish_encode_header(); - println!( - "header_payload_buffer: {:?}", - f_encoder.header_payload_buffer[..f_encoder.remaining_header_payload].to_vec() - ); - assert_eq!( - s_res, - f_encoder.header_payload_buffer[..f_encoder.remaining_header_payload].to_vec() - ); } - #[test] - /// The encoder sets the dynamic table capacity, inserts a header with a - /// dynamic name reference, then sends a potentially blocking, encoded - /// field section referencing this new entry. The decoder acknowledges - /// processing the encoded field section, which implicitly acknowledges - /// all dynamic table insertions up to the Required Insert Count. - fn dynamic_table() { - let mut table = DynamicTable::with_empty(); - - let mut f_encoder = FrameEncoder::new(&mut table, true, 0, 0); - - f_encoder.update_dyn_size(220); - let s_res = decode("03811011").unwrap(); - let q_res = - decode("3fbd01c00f7777772e6578616d706c652e636f6dc10c2f73616d706c652f70617468").unwrap(); - let headers = [ - (Field::Authority, String::from("www.example.com")), - (Field::Path, String::from("/sample/path")), - ]; - for (field, value) in headers.iter() { - println!("encoding: HEADER: {:?} , VALUE: {:?}", field, value); - let mut part = Parts::new(); - part.update(field.clone(), value.clone()); - let header = Headers::new(part.clone()); - f_encoder.encode_header(&header); + fn encode_data_len(&self, frame_buf: &mut [u8]) -> Result { + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Data(d) = frame.payload() { + encode_var_integer((*d).data().len() as u64, frame_buf) + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } + + fn encode_cancel_push(&mut self, frame_buf: &mut [u8]) -> Result { + let frame = self + .current_frame + .as_ref() + .ok_or(H3Error::Encode(NoCurrentFrame))?; + if let Payload::CancelPush(push) = frame.payload() { + let size = + encode_var_integer(octets::varint_len(*push.get_push_id()) as u64, frame_buf)?; + // Ensure that it can be completed in a one go. + self.state = FrameEncoderState::FrameComplete; + encode_var_integer(*push.get_push_id(), &mut frame_buf[size..]) + } else { + Err(WrongTypeFrame.into()) } - f_encoder.finish_encode_header(); - println!( - "header_payload_buffer: {:?}", - f_encoder.header_payload_buffer[..f_encoder.remaining_header_payload].to_vec() - ); - println!( - "qpack_encoder_buffer: {:?}", - f_encoder.qpack_encoder_buffer[..f_encoder.remaining_qpack_payload].to_vec() - ); - assert_eq!( - s_res, - f_encoder.header_payload_buffer[..f_encoder.remaining_header_payload].to_vec() - ); - assert_eq!( - q_res, - f_encoder.qpack_encoder_buffer[..f_encoder.remaining_qpack_payload].to_vec() - ); } + + fn encode_goaway(&mut self, frame_buf: &mut [u8]) -> Result { + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Goaway(away) = frame.payload() { + let size = + encode_var_integer(octets::varint_len(*away.get_id()) as u64, frame_buf)?; + // Ensure that it can be completed in a one go. + self.state = FrameEncoderState::FrameComplete; + + encode_var_integer(*away.get_id(), &mut frame_buf[size..]) + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } + + fn encode_max_push_id(&mut self, frame_buf: &mut [u8]) -> Result { + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::MaxPushId(push) = frame.payload() { + let size = + encode_var_integer(octets::varint_len(*push.get_id()) as u64, frame_buf)?; + // Ensure that it can be completed in a one go. + self.state = FrameEncoderState::FrameComplete; + + encode_var_integer(*push.get_id(), &mut frame_buf[size..]) + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } + + fn encode_settings_len(&self, frame_buf: &mut [u8]) -> Result { + let mut written = 0; + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Settings(settings) = frame.payload() { + if let Some(v) = settings.max_fied_section_size() { + written += octets::varint_len(frame::SETTING_MAX_FIELD_SECTION_SIZE); + written += octets::varint_len(v); + } + if let Some(v) = settings.connect_protocol_enabled() { + written += octets::varint_len(frame::SETTING_ENABLE_CONNECT_PROTOCOL); + written += octets::varint_len(v); + } + if let Some(v) = settings.qpack_max_table_capacity() { + written += octets::varint_len(frame::SETTING_QPACK_MAX_TABLE_CAPACITY); + written += octets::varint_len(v); + } + if let Some(v) = settings.qpack_block_stream() { + written += octets::varint_len(frame::SETTING_QPACK_BLOCKED_STREAMS); + written += octets::varint_len(v); + } + if let Some(v) = settings.h3_datagram() { + written += octets::varint_len(frame::SETTING_H3_DATAGRAM); + written += octets::varint_len(v); + } + if let Some(v) = settings.additional() { + // ensure frame buf is enough capacity. + if v.len() > 50 { + return Err(TooManySettings.into()); + } + for (key, value) in v.iter() { + written += octets::varint_len(*key); + written += octets::varint_len(*value); + } + } + let var_written = encode_var_integer(written as u64, frame_buf)?; + Ok(var_written) + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } + + fn encode_headers_payload( + &mut self, + qpack_encoder: &mut QpackEncoder, + frame_buf: &mut [u8], + inst_buf: &mut [u8], + ) -> Result<(usize, usize), H3Error> { + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Headers(_headers) = frame.payload() { + self.encode_headers_with_qpack(qpack_encoder, frame_buf, inst_buf) + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } + + fn encode_data_payload(&mut self, frame_buf: &mut [u8]) -> Result { + if let Some(frame) = self.current_frame.as_ref() { + if let Payload::Data(d) = frame.payload() { + let data = d.data(); + let buf_size = frame_buf.len(); + let remaining = data.len() - self.payload_offset; + if buf_size >= remaining { + frame_buf[..remaining].copy_from_slice(&data.as_slice()[self.payload_offset..]); + self.payload_offset = 0; + self.state = FrameEncoderState::FrameComplete; + Ok(remaining) + } else { + frame_buf.copy_from_slice( + &data.as_slice()[self.payload_offset..self.payload_offset + buf_size], + ); + self.payload_offset += buf_size; + Ok(buf_size) + } + } else { + Err(WrongTypeFrame.into()) + } + } else { + Err(NoCurrentFrame.into()) + } + } +} + +fn encode_var_integer(src: u64, frame_buf: &mut [u8]) -> Result { + let mut writable_buf = WritableBytes::from(frame_buf); + writable_buf.write_varint(src)?; + let size = writable_buf.index(); + Ok(size) } diff --git a/ylong_http/src/h3/error.rs b/ylong_http/src/h3/error.rs index 0d03a7e..8efe573 100644 --- a/ylong_http/src/h3/error.rs +++ b/ylong_http/src/h3/error.rs @@ -10,3 +10,191 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +use std::convert::Infallible; +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; + +use crate::h3::qpack::error::QpackError; + +/// HTTP3 errors. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum H3Error { + /// Serialization error. + Encode(EncodeError), + /// Deserialization error. + Decode(DecodeError), + /// Common error during serialization or deserialization. + Serialize(CommonError), + /// Connection level error. + Connection(H3ErrorCode), + /// Stream level error. + Stream(u64, H3ErrorCode), +} + +/// Error during serialization. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum EncodeError { + /// The set frame could not be found during serialization. + NoCurrentFrame, + /// The type of frame set does not match the serialized one. + WrongTypeFrame, + /// The previous frame has not been serialized. + RepeatSetFrame, + /// Sets a frame of unknown type. + UnknownFrameType, + /// Too many additional Settings are encoded. + TooManySettings, + /// qpack encoder encoding error. + QpackError(QpackError), +} + +/// Error during deserialization. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum DecodeError { + /// The frame type does not correspond to the stream type. + UnexpectedFrame(u64), + /// Qpack decoder decoding error. + QpackError(QpackError), + /// The payload length resolved is different from the actual data. + FrameSizeError(u64), + /// Http3 does not allow the type of setting. + UnsupportedSetting(u64), +} + +/// Errors during serialization and deserialization, +/// usually occur during variable interger serialization and deserialization. +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum CommonError { + /// The buf used to store serialized data is too short. + BufferTooShort, + /// The field for the frame is missing. + FieldMissing, + /// Computation time overflow. + CalculateOverflow, + /// Internal error. + InternalError, +} + +/// Common http3 error codes defined in the rfc documentation. +/// Refers to [`iana`]. +/// +/// [`iana`]: https://www.iana.org/assignments/http3-parameters/http3-parameters.xhtml#http3-parameters-error-codes +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub enum H3ErrorCode { + /// Datagram or Capsule Protocol parse error. + H3DatagramError = 0x33, + /// No error. + H3NoError = 0x100, + /// General protocol error. + H3GeneralProtocolError = 0x101, + /// Internal error. + H3InternalError = 0x102, + /// Stream creation error. + H3StreamCreationError = 0x103, + /// Critical stream was closed. + H3ClosedCriticalStream = 0x104, + /// Frame not permitted in the current state. + H3FrameUnexpected = 0x105, + /// Frame violated layout or size rules. + H3FrameError = 0x106, + /// Peer generating excessive load. + H3ExcessiveLoad = 0x107, + /// An identifier was used incorrectly. + H3IdError = 0x108, + /// SETTINGS frame contained invalid values. + H3SettingsError = 0x109, + /// No SETTINGS frame received. + H3MissingSettings = 0x10A, + /// Request not processed. + H3RequestRejected = 0x10B, + /// Data no longer needed. + H3RequestCancelled = 0x10C, + /// Stream terminated early. + H3RequestIncomplete = 0x10D, + /// Malformed message. + H3MessageError = 0x10E, + /// TCP reset or error on CONNECT request. + H3ConnectError = 0x10F, + /// Retry over HTTP/1.1. + H3VersionFallback = 0x110, + /// Decoding of a field section failed. + QPACKDecompressionFailed = 0x200, + /// Error on the encoder stream. + QPACKEncoderStreamError = 0x201, + /// Error on the decoder stream. + QPACKDecoderStreamError = 0x202, +} + +impl From for H3ErrorCode { + fn from(value: u64) -> Self { + match value { + 0x33 => H3ErrorCode::H3DatagramError, + 0x100 => H3ErrorCode::H3NoError, + 0x101 => H3ErrorCode::H3GeneralProtocolError, + 0x102 => H3ErrorCode::H3InternalError, + 0x103 => H3ErrorCode::H3StreamCreationError, + 0x104 => H3ErrorCode::H3ClosedCriticalStream, + 0x105 => H3ErrorCode::H3FrameUnexpected, + 0x106 => H3ErrorCode::H3FrameError, + 0x107 => H3ErrorCode::H3ExcessiveLoad, + 0x108 => H3ErrorCode::H3IdError, + 0x109 => H3ErrorCode::H3SettingsError, + 0x10A => H3ErrorCode::H3MissingSettings, + 0x10B => H3ErrorCode::H3RequestRejected, + 0x10C => H3ErrorCode::H3RequestCancelled, + 0x10D => H3ErrorCode::H3RequestIncomplete, + 0x10E => H3ErrorCode::H3MessageError, + 0x10F => H3ErrorCode::H3ConnectError, + 0x110 => H3ErrorCode::H3VersionFallback, + 0x200 => H3ErrorCode::QPACKDecompressionFailed, + 0x201 => H3ErrorCode::QPACKEncoderStreamError, + 0x202 => H3ErrorCode::QPACKDecoderStreamError, + _ => H3ErrorCode::H3GeneralProtocolError, + } + } +} + +impl From for DecodeError { + fn from(value: QpackError) -> Self { + DecodeError::QpackError(value) + } +} + +impl From for EncodeError { + fn from(value: QpackError) -> Self { + EncodeError::QpackError(value) + } +} + +impl From for H3Error { + fn from(value: EncodeError) -> Self { + H3Error::Encode(value) + } +} + +impl From for H3Error { + fn from(value: DecodeError) -> Self { + H3Error::Decode(value) + } +} + +impl From for H3Error { + fn from(value: CommonError) -> Self { + H3Error::Serialize(value) + } +} + +impl From for H3Error { + fn from(_value: Infallible) -> Self { + unreachable!() + } +} + +impl Display for H3Error { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + Debug::fmt(self, f) + } +} + +impl Error for H3Error {} diff --git a/ylong_http/src/h3/frame.rs b/ylong_http/src/h3/frame.rs index 77a2bbb..84b6e2b 100644 --- a/ylong_http/src/h3/frame.rs +++ b/ylong_http/src/h3/frame.rs @@ -1,278 +1,307 @@ -// Copyright (c) 2023 Huawei Device Co., Ltd. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::h3::parts::Parts; - -pub const DATA_FRAME_TYPE_ID: u64 = 0x0; -pub const HEADERS_FRAME_TYPE_ID: u64 = 0x1; -pub const CANCEL_PUSH_FRAME_TYPE_ID: u64 = 0x3; -pub const SETTINGS_FRAME_TYPE_ID: u64 = 0x4; -pub const PUSH_PROMISE_FRAME_TYPE_ID: u64 = 0x5; -pub const GOAWAY_FRAME_TYPE_ID: u64 = 0x7; -pub const MAX_PUSH_FRAME_TYPE_ID: u64 = 0xD; -pub const PRIORITY_UPDATE_FRAME_REQUEST_TYPE_ID: u64 = 0xF0700; -pub const PRIORITY_UPDATE_FRAME_PUSH_TYPE_ID: u64 = 0xF0701; -pub const SETTINGS_QPACK_MAX_TABLE_CAPACITY: u64 = 0x1; -pub const SETTINGS_MAX_FIELD_SECTION_SIZE: u64 = 0x6; -pub const SETTINGS_QPACK_BLOCKED_STREAMS: u64 = 0x7; -pub const SETTINGS_ENABLE_CONNECT_PROTOCOL: u64 = 0x8; -pub const SETTINGS_H3_DATAGRAM_00: u64 = 0x276; -pub const SETTINGS_H3_DATAGRAM: u64 = 0x33; -// Permit between 16 maximally-encoded and 128 minimally-encoded SETTINGS. -const MAX_SETTINGS_PAYLOAD_SIZE: usize = 256; - -#[derive(Clone)] -pub struct Frame { - ty: u64, - len: u64, - payload: Payload, -} - -#[derive(Clone)] -pub enum Payload { - /// HEADERS frame payload. - Headers(Headers), - /// DATA frame payload. - Data(Data), - /// SETTINGS frame payload. - Settings(Settings), - /// CancelPush frame payload. - CancelPush(CancelPush), - /// PushPromise frame payload. - PushPromise(PushPromise), - /// GOAWAY frame payload. - Goaway(GoAway), - /// MaxPushId frame payload. - MaxPushId(MaxPushId), - /// Unknown frame payload. - Unknown(Unknown), -} - -#[derive(Clone)] -pub struct Headers { - headers: Vec, - parts: Parts, -} - -#[derive(Clone)] -pub struct Data { - data: Vec, -} - -#[derive(Clone)] -pub struct Settings { - max_field_section_size: Option, - qpack_max_table_capacity: Option, - qpack_blocked_streams: Option, - connect_protocol_enabled: Option, - h3_datagram: Option, - raw: Option>, -} - -#[derive(Clone)] -pub struct CancelPush { - push_id: u64, -} - -#[derive(Clone)] -pub struct PushPromise { - push_id: u64, - headers: Vec, -} - -#[derive(Clone)] -pub struct GoAway { - id: u64, -} - -pub struct MaxPushId { - push_id: u64, -} - -pub struct Unknown { - raw_type: u64, - len: u64, -} - -impl Frame { - pub fn new(ty: u64, len: u64, payload: Payload) -> Self { - Frame { ty, len, payload } - } - - pub fn frame_type(&self) -> &u64 { - &self.ty - } - - pub fn frame_len(&self) -> &u64 { - &self.len - } - - pub fn payload(&self) -> &Payload { - &self.payload - } - - pub(crate) fn payload_mut(&mut self) -> &mut Payload { - &mut self.payload - } -} - -// settings结构体相当于quiche中setting结构体 -impl Settings { - /// Creates a new Settings instance containing the provided settings. - pub fn new() -> Self { - Settings { - max_field_section_size: None, - qpack_max_table_capacity: None, - qpack_blocked_streams: None, - connect_protocol_enabled: None, - h3_datagram: None, - raw: None, - } - } - - /// SETTINGS_HEADER_TABLE_SIZE (0x01) setting. - pub fn max_fied_section_size(mut self, size: u64) -> Self { - self.max_field_section_size = Some(size); - self - } - - /// SETTINGS_ENABLE_PUSH (0x02) setting. - pub fn qpack_max_table_capacity(mut self, size: u64) -> Self { - self.qpack_max_table_capacity = Some(size); - self - } - - /// SETTINGS_MAX_FRAME_SIZE (0x05) setting. - pub fn qpack_block_stream(mut self, size: u64) -> Self { - self.qpack_blocked_streams = Some(size); - self - } - - /// SETTINGS_MAX_HEADER_LIST_SIZE (0x06) setting. - pub fn connect_protocol_enabled(mut self, size: u64) -> Self { - self.connect_protocol_enabled = Some(size); - self - } - - pub fn h3_datagram(mut self, size: u64) -> Self { - self.h3_datagram = Some(size); - self - } - - /// SETTINGS_HEADER_TABLE_SIZE (0x01) setting. - pub fn get_max_fied_section_size(&self) -> &Option { - &self.max_field_section_size - } - - /// SETTINGS_ENABLE_PUSH (0x02) setting. - pub fn get_qpack_max_table_capacity(&self) -> &Option { - &self.qpack_max_table_capacity - } - - /// SETTINGS_MAX_FRAME_SIZE (0x05) setting. - pub fn get_qpack_block_stream(&self) -> &Option { - &self.qpack_blocked_streams - } - - /// SETTINGS_MAX_HEADER_LIST_SIZE (0x06) setting. - pub fn get_connect_protocol_enabled(&self) -> &Option { - &self.connect_protocol_enabled - } - - pub fn get_h3_datagram(&self) -> &Option { - &self.h3_datagram - } -} - -impl Data { - /// Creates a new Data instance containing the provided data. - pub fn new(data: Vec) -> Self { - Data { data } - } - - /// Return the `Vec` that contains the data payload. - pub fn data(&self) -> &Vec { - &self.data - } -} - -impl CancelPush { - /// Creates a new CancelPush instance from the provided Parts. - pub fn new(id: u64) -> Self { - CancelPush { push_id: id } - } - - pub fn get_push_id(&self) -> &u64 { - &self.push_id - } -} - -impl Headers { - /// Creates a new Headers instance from the provided Parts. - pub fn new(parts: Parts) -> Self { - Headers { - headers: vec![0; 16383], - parts, - } - } - - /// Returns pseudo headers and other headers - pub fn get_headers(&self) -> &Vec { - &self.headers - } - - pub fn get_part(&self) -> Parts { - self.parts.clone() - } -} - -impl PushPromise { - /// Creates a new PushPromise instance from the provided Parts. - pub fn new(push_id: u64, header: Vec) -> Self { - PushPromise { - push_id, - headers: header, - } - } - - pub fn get_push_id(&self) -> u64 { - self.push_id - } - - /// Returns a copy of the internal parts of the Headers. - pub(crate) fn get_headers(&self) -> &Vec { - &self.headers.clone() - } -} - -impl GoAway { - /// Creates a new GoAway instance from the provided Parts. - pub fn new(id: u64) -> Self { - GoAway { id } - } - - pub fn get_id(&self) -> &u64 { - &self.id - } -} - -impl MaxPushId { - /// Creates a new MaxPushId instance from the provided Parts. - pub fn new(push_id: u64) -> Self { - MaxPushId { push_id } - } - - pub fn get_id(&self) -> &u64 { - &self.push_id - } -} +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::h3::parts::Parts; + +/// Data Frame type code. +pub const DATA_FRAME_TYPE: u64 = 0x0; +/// HEADERS Frame type code. +pub const HEADERS_FRAME_TYPE: u64 = 0x1; +/// CANCEL_PUSH Frame type code. +pub const CANCEL_PUSH_FRAME_TYPE: u64 = 0x3; +/// SETTINGS Frame type code. +pub const SETTINGS_FRAME_TYPE: u64 = 0x4; +/// PUSH_PROMISE Frame type code. +pub const PUSH_PROMISE_FRAME_TYPE: u64 = 0x5; +/// GOAWAY Frame type code. +pub const GOAWAY_FRAME_TYPE: u64 = 0x7; +/// MAX_PUSH_ID Frame type code. +pub const MAX_PUSH_ID_FRAME_TYPE: u64 = 0xD; +/// SETTING_QPACK_MAX_TABLE_CAPACITY setting code. +pub const SETTING_QPACK_MAX_TABLE_CAPACITY: u64 = 0x1; +/// SETTING_MAX_FIELD_SECTION_SIZE setting code. +pub const SETTING_MAX_FIELD_SECTION_SIZE: u64 = 0x6; +/// SETTING_QPACK_BLOCKED_STREAMS setting code. +pub const SETTING_QPACK_BLOCKED_STREAMS: u64 = 0x7; +/// SETTING_ENABLE_CONNECT_PROTOCOL setting code. +pub const SETTING_ENABLE_CONNECT_PROTOCOL: u64 = 0x8; +/// SETTING_H3_DATAGRAM setting code. +pub const SETTING_H3_DATAGRAM: u64 = 0x33; +/// MAX_SETTING_PAYLOAD_SIZE setting code. +// Permit between 16 maximally-encoded and 128 minimally-encoded SETTINGS. +const MAX_SETTING_PAYLOAD_SIZE: usize = 256; + +/// Http3 frame definition. +#[derive(Clone, Debug)] +pub struct Frame { + ty: u64, + payload: Payload, +} + +/// Http3 frame payload. +#[derive(Clone, Debug)] +pub enum Payload { + /// HEADERS frame payload. + Headers(Headers), + /// DATA frame payload. + Data(Data), + /// SETTINGS frame payload. + Settings(Settings), + /// CancelPush frame payload. + CancelPush(CancelPush), + /// PushPromise frame payload. + PushPromise(PushPromise), + /// GOAWAY frame payload. + Goaway(GoAway), + /// MaxPushId frame payload. + MaxPushId(MaxPushId), + /// Unknown frame payload. + Unknown(Unknown), +} + +/// Http3 Headers frame payload, which also contains instructions to send when +/// decoding. +#[derive(Clone, Debug)] +pub struct Headers { + parts: Parts, + ins: Option>, +} + +/// Http3 Data frame payload, containing the body data. +#[derive(Clone, Debug)] +pub struct Data { + data: Vec, +} + +/// Http3 Settings frame payload. +#[derive(Clone, Default, Debug)] +pub struct Settings { + max_field_section_size: Option, + qpack_max_table_capacity: Option, + qpack_blocked_streams: Option, + connect_protocol_enabled: Option, + h3_datagram: Option, + additional: Option>, +} + +/// Http3 CancelPush frame payload. +#[derive(Clone, Debug)] +pub struct CancelPush { + push_id: u64, +} + +/// Http3 PushPromise frame payload. +#[derive(Clone, Debug)] +pub struct PushPromise { + push_id: u64, + parts: Parts, + ins: Option>, +} + +/// Http3 GoAway frame payload. +#[derive(Clone, Debug)] +pub struct GoAway { + id: u64, +} + +/// Http3 MaxPushId frame payload. +#[derive(Clone, Debug)] +pub struct MaxPushId { + push_id: u64, +} + +/// Http3 Unknown frame payload. +#[derive(Clone, Debug)] +pub struct Unknown { + raw_type: u64, + len: u64, +} + +impl Frame { + /// Constructs a Frame with type and payload. + pub fn new(ty: u64, payload: Payload) -> Self { + Frame { ty, payload } + } + + /// Gets frame type. + pub fn frame_type(&self) -> &u64 { + &self.ty + } + + /// Gets frame payload. + pub fn payload(&self) -> &Payload { + &self.payload + } + + /// Gets a mutable frame payload of current frame. + pub(crate) fn payload_mut(&mut self) -> &mut Payload { + &mut self.payload + } +} + +impl Settings { + /// Sets SETTINGS_HEADER_TABLE_SIZE (0x01) setting. + pub fn set_max_field_section_size(&mut self, size: u64) { + self.max_field_section_size = Some(size) + } + + /// Sets SETTINGS_ENABLE_PUSH (0x02) setting. + pub fn set_qpack_max_table_capacity(&mut self, size: u64) { + self.qpack_max_table_capacity = Some(size) + } + + /// Sets SETTINGS_MAX_FRAME_SIZE (0x05) setting. + pub fn set_qpack_block_stream(&mut self, size: u64) { + self.qpack_blocked_streams = Some(size) + } + + /// Sets SETTINGS_MAX_HEADER_LIST_SIZE (0x06) setting. + pub fn set_connect_protocol_enabled(&mut self, size: u64) { + self.connect_protocol_enabled = Some(size) + } + + /// Sets SETTINGS_H3_DATAGRAM setting. + pub fn set_h3_datagram(&mut self, size: u64) { + self.h3_datagram = Some(size); + } + + /// Sets additional settings. + pub fn set_additional(&mut self, addition: Vec<(u64, u64)>) { + self.additional = Some(addition) + } + + /// Gets SETTINGS_MAX_FIELD_SECTION_SIZE setting. + pub fn max_fied_section_size(&self) -> Option { + self.max_field_section_size + } + + /// Gets SETTINGS_QPACK_MAX_TABLE_CAPACITY setting. + pub fn qpack_max_table_capacity(&self) -> Option { + self.qpack_max_table_capacity + } + + /// Gets SETTINGS_QPACK_BLOCKED_STREAMS setting. + pub fn qpack_block_stream(&self) -> Option { + self.qpack_blocked_streams + } + + /// Gets SETTINGS_ENABLE_CONNECT_PROTOCOL setting. + pub fn connect_protocol_enabled(&self) -> Option { + self.connect_protocol_enabled + } + + /// Gets SETTINGS_H3_DATAGRAM setting. + pub fn h3_datagram(&self) -> Option { + self.h3_datagram + } + + /// Gets additional settings. + pub fn additional(&self) -> &Option> { + &self.additional + } +} + +impl Data { + /// Creates a new Data instance containing the provided data. + pub fn new(data: Vec) -> Self { + Data { data } + } + + /// Return the `Vec` that contains the data payload. + pub fn data(&self) -> &Vec { + &self.data + } +} + +impl CancelPush { + /// Creates a new CancelPush instance from the provided Parts. + pub fn new(id: u64) -> Self { + CancelPush { push_id: id } + } + + /// Gets push id of CancelPush payload. + pub fn get_push_id(&self) -> &u64 { + &self.push_id + } +} + +impl Headers { + /// Creates a new Headers instance from the provided Parts. + pub fn new(parts: Parts) -> Self { + Headers { parts, ins: None } + } + + /// Gets the instructions generated by qpack decoder after decoding headers + /// frame. + pub fn get_instruction(&self) -> &Option> { + &self.ins + } + + /// Gets headers part of Headers frame payload. + pub fn get_part(&self) -> Parts { + self.parts.clone() + } + + pub(crate) fn set_instruction(&mut self, buf: Vec) { + self.ins = Some(buf) + } +} + +impl PushPromise { + /// Creates a new PushPromise instance from the provided Parts. + pub fn new(push_id: u64, parts: Parts) -> Self { + PushPromise { + push_id, + parts, + ins: None, + } + } + + /// Gets push id of PushPromise payload. + pub fn get_push_id(&self) -> u64 { + self.push_id + } + + /// Returns a copy of the internal parts of the Headers. + pub(crate) fn get_parts(&self) -> &Parts { + &self.parts + } + + pub(crate) fn set_instruction(&mut self, buf: Vec) { + self.ins = Some(buf) + } +} + +impl GoAway { + /// Creates a new GoAway instance from the provided Parts. + pub fn new(id: u64) -> Self { + GoAway { id } + } + + /// Gets go away stream id. + pub fn get_id(&self) -> &u64 { + &self.id + } +} + +impl MaxPushId { + /// Creates a new MaxPushId instance from the provided Parts. + pub fn new(push_id: u64) -> Self { + MaxPushId { push_id } + } + + /// Gets allowed max push stream id. + pub fn get_id(&self) -> &u64 { + &self.push_id + } +} diff --git a/ylong_http/src/h3/mod.rs b/ylong_http/src/h3/mod.rs index 80d794a..ba24510 100644 --- a/ylong_http/src/h3/mod.rs +++ b/ylong_http/src/h3/mod.rs @@ -13,14 +13,30 @@ // TODO: `HTTP/3` Module. -mod connection; mod decoder; mod encoder; mod error; mod frame; -pub mod parts; -pub mod pseudo; -pub mod qpack; +mod octets; +mod parts; +mod pseudo; +mod qpack; // mod octets; mod stream; -pub use frame::Frame; +pub use decoder::FrameDecoder; +pub use encoder::FrameEncoder; +pub use error::{DecodeError, EncodeError, H3Error, H3ErrorCode}; +pub use frame::{ + Data, Frame, Headers, Payload, Settings, DATA_FRAME_TYPE, HEADERS_FRAME_TYPE, + SETTINGS_FRAME_TYPE, +}; +pub use parts::Parts; +pub use pseudo::PseudoHeaders; +pub use stream::{ + FrameKind, Frames, StreamMessage, CONTROL_STREAM_TYPE, QPACK_DECODER_STREAM_TYPE, + QPACK_ENCODER_STREAM_TYPE, +}; + +pub(crate) fn is_bidirectional(id: u64) -> bool { + (id & 0x02) == 0 +} diff --git a/ylong_http/src/h3/octets.rs b/ylong_http/src/h3/octets.rs index 6595219..c60f13c 100644 --- a/ylong_http/src/h3/octets.rs +++ b/ylong_http/src/h3/octets.rs @@ -1,303 +1,224 @@ -// Copyright (c) 2023 Huawei Device Co., Ltd. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::convert::TryFrom; -use std::io::Read; - -pub type Result = std::result::Result; - -// pub trait TransferData { -// fn peek_u(&self, ty: &str, len: usize) -> &str; -// fn get_u(&self, ty: &str, len: usize) -> &str; -// fn put_u(&self, ty: &str, value: usize,len: usize) -> &mut [u8]; -// } -// -// impl TransferData for Octets { -// fn peek_u(&self, ty: &str, len: usize) -> &str { -// todo!() -// } -// -// fn get_u(&self, ty: &str, len: usize) -> &str { -// todo!() -// } -// -// fn put_u(&self, ty: &str, value: usize, len: usize) -> &mut [u8] { -// todo!() -// } -// } -// buf fragment splited by out offset -#[derive(Debug, PartialEq, Eq)] -pub struct ReadVarint<'a> { - buf: &'a [u8], -} - -impl<'a> ReadVarint<'a> { - pub fn new(buf: &'a [u8]) -> Self { - ReadVarint { buf } - } - - pub fn into_u8(&mut self) -> Result { - const len: usize = 1; - - if self.buf.len() < len { - return Err(BufferTooShortError); - } - let bytes: [u8; len] = <[u8; len]>::try_from(self.buf[..len].as_ref()).unwrap(); - if cfg!(target_endian = "big") { - let res: u8 = u8::from_be_bytes(bytes); - Ok(res) - } else { - let res: u8 = u8::from_le_bytes(bytes); - Ok(res) - } - } - - pub fn into_u16(&mut self) -> Result { - const len: usize = 2; - - if self.buf.len() < len { - return Err(BufferTooShortError); - } - let bytes: [u8; len] = <[u8; len]>::try_from(self.buf[..len].as_ref()).unwrap(); - if cfg!(target_endian = "big") { - let res: u16 = u16::from_be_bytes(bytes); - Ok(res) - } else { - let res: u16 = u16::from_le_bytes(bytes); - Ok(res) - } - } - - pub fn into_u32(&mut self) -> Result { - const len: usize = 4; - - if self.buf.len() < len { - return Err(BufferTooShortError); - } - let bytes: [u8; len] = <[u8; len]>::try_from(self.buf[..len].as_ref()).unwrap(); - if cfg!(target_endian = "big") { - let res: u32 = u32::from_be_bytes(bytes); - Ok(res) - } else { - let res: u32 = u32::from_le_bytes(bytes); - Ok(res) - } - } - - pub fn into_u64(&mut self) -> Result { - const len: usize = 8; - - if self.buf.len() < len { - return Err(BufferTooShortError); - } - let bytes: [u8; len] = <[u8; len]>::try_from(self.buf[..len].as_ref()).unwrap(); - if cfg!(target_endian = "big") { - let res: u64 = u64::from_be_bytes(bytes); - Ok(res) - } else { - let res: u64 = u64::from_le_bytes(bytes); - Ok(res) - } - } - - /// Reads an unsigned variable-length integer in network byte-order from - /// the current offset and advances the buffer. - pub fn get_varint(&mut self) -> Result { - let first = self.into_u8()?; - - let len = varint_parse_len(first); - - if len > self.cap() { - return Err(BufferTooShortError); - } - - let out = match len { - 1 => u64::from(self.into_u8()?), - - 2 => u64::from(self.into_u16()? & 0x3fff), - - 4 => u64::from(self.into_u32()? & 0x3fffffff), - - 8 => self.into_u64()? & 0x3fffffffffffffff, - - _ => unreachable!(), - }; - - Ok(out) - } - - /// Returns the remaining capacity in the buffer. - pub fn cap(&self) -> usize { - self.buf.len() - self.off - } -} - -// Component encoding status. -enum TokenStatus { - // The current component is completely encoded. - Complete(T), - // The current component is partially encoded. - Partial(E), -} - -type TokenResult = Result>; - -struct WriteData<'a> { - src: &'a [u8], - src_idx: &'a mut usize, - dst: &'a mut [u8], -} - -impl<'a> WriteData<'a> { - fn new(src: &'a [u8], src_idx: &'a mut usize, dst: &'a mut [u8]) -> Self { - WriteData { src, src_idx, dst } - } - - fn write(&mut self) -> TokenResult { - let src_idx = *self.src_idx; - let input_len = self.src.len() - src_idx; - let output_len = self.dst.len(); - let num = (&self.src[src_idx..]).read(self.dst).unwrap(); - if output_len >= input_len { - return Ok(TokenStatus::Complete(num)); - } - *self.src_idx += num; - Ok(TokenStatus::Partial(num)) - } -} - -pub struct WriteVarint<'a> { - src: &'a [u8], - src_idx: &'a mut usize, - dst: &'a mut [u8], -} - -impl<'a> WriteVarint<'a> { - // pub fn new(buf: &'a mut [u8]) -> Self { - // WriteVarint { buf } - // } - pub fn new(src: &'a [u8], src_idx: &'a mut usize, dst: &'a mut [u8]) -> Self { - // src需要从value转码过来 - WriteVarint { src, src_idx, dst } - } - - /// Writes an unsigned 8-bit integer at the current offset and advances - /// the buffer. - pub fn write_u8(&mut self, value: u8) -> Result { - const len: usize = 1; - // buf长度不够问题返回err,再由外层处理 - if self.buf.len() != len { - return Err(BufferTooShortError); - } - - let bytes: [u8; len] = value.to_be_bytes(); - self.buf.copy_from_slice(bytes.as_slice()); - Ok(len) - } - - pub fn write_u16(&mut self, value: u16) -> Result { - const len: usize = 2; - // buf长度不够问题返回err,再由外层处理 - if self.buf.len() != len { - return Err(BufferTooShortError); - } - - let bytes: [u8; len] = value.to_be_bytes(); - self.buf.copy_from_slice(bytes.as_slice()); - Ok(len) - } - - pub fn write_u32(&mut self, value: u32) -> Result { - const len: usize = 4; - // buf长度不够问题返回err,再由外层处理 - if self.buf.len() != len { - return Err(BufferTooShortError); - } - - let bytes: [u8; len] = value.to_be_bytes(); - self.buf.copy_from_slice(bytes.as_slice()); - Ok(len) - } - - pub fn write_u64(&mut self, value: u64) -> Result { - const len: usize = 8; - // buf长度不够问题返回err,再由外层处理 - if self.buf.len() != len { - return Err(BufferTooShortError); - } - - let bytes: [u8; len] = value.to_be_bytes(); - self.buf.copy_from_slice(bytes.as_slice()); - Ok(len) - } - - /// Writes an unsigned variable-length integer in network byte-order at the - /// current offset and advances the buffer. - pub fn write_varint(&mut self, value: u64) -> Result { - self.write_varint_with_len(value, varint_len(value)) - } - - pub fn write_varint_with_len(&mut self, value: u64, len: usize) -> Result { - if self.cap() < len { - return Err(BufferTooShortError); - } - - let res = match len { - 1 => self.write_u8(value as u8)?, - - 2 => { - let size = self.write_u16(value as u16)?; - *self.buf[0] |= 0x40; - size - } - - 4 => { - let size = self.write_u32(value as u32)?; - *self.buf[0] |= 0x80; - size - } - - 8 => { - let size = self.write_u64(value)?; - *self.buf[0] |= 0xc0; - size - } - - _ => panic!("value is too large for varint"), - }; - - Ok(res) - } -} - -/// Returns how many bytes it would take to encode `v` as a variable-length -/// integer. -pub const fn varint_len(v: u64) -> usize { - match v { - 0..=63 => 1, - 64..=16383 => 2, - 16384..=1_073_741_823 => 4, - 1_073_741_824..=4_611_686_018_427_387_903 => 8, - _ => {unreachable!()} - } -} - -/// Returns how long the variable-length integer is, given its first byte. -pub const fn varint_parse_len(byte: u8) -> usize { - let byte = byte >> 6; - if byte <= 3 { - 1 << byte - } else { - unreachable!() - } -} +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::convert::TryFrom; +use std::io::Read; + +use ylong_runtime::iter::parallel::ParSplit; + +use crate::h3::error::CommonError::BufferTooShort; +use crate::h3::error::{EncodeError, H3Error}; + +pub type Result = std::result::Result; + +macro_rules! peek_bytes { + ($buf: expr, $ty: ty, $len: expr) => {{ + if $buf.len() < $len { + return Err(H3Error::Serialize(BufferTooShort)); + } + let bytes: [u8; $len] = <[u8; $len]>::try_from($buf[..$len].as_ref()).unwrap(); + let res = <$ty>::from_be_bytes(bytes); + Ok(res) + }}; +} + +macro_rules! poll_bytes { + ($this: expr, $ty: ty, $len: expr) => {{ + let res = peek_bytes!($this.buf[$this.idx..], $ty, $len); + $this.idx += $len; + res + }}; +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ReadableBytes<'a> { + buf: &'a [u8], + idx: usize, +} + +impl<'a> ReadableBytes<'a> { + pub fn from(buf: &'a [u8]) -> Self { + ReadableBytes { buf, idx: 0 } + } + + pub(crate) fn peek_u8(&mut self) -> Result { + peek_bytes!(self.buf[self.idx..], u8, 1) + } + + pub fn poll_u8(&mut self) -> Result { + poll_bytes!(self, u8, 1) + } + + pub fn poll_u16(&mut self) -> Result { + poll_bytes!(self, u16, 2) + } + + pub fn poll_u32(&mut self) -> Result { + poll_bytes!(self, u32, 4) + } + + pub fn poll_u64(&mut self) -> Result { + poll_bytes!(self, u64, 8) + } + + /// Reads an unsigned variable-length integer in network byte-order from + /// the current offset and advances the buffer. + pub fn get_varint(&mut self) -> Result { + let first = self.peek_u8()?; + let len = parse_varint_len(first); + if len > self.cap() { + return Err(BufferTooShort.into()); + } + let out = match len { + 1 => u64::from(self.poll_u8()?), + + 2 => u64::from(self.poll_u16()? & 0x3fff), + + 4 => u64::from(self.poll_u32()? & 0x3fffffff), + + 8 => self.poll_u64()? & 0x3fffffffffffffff, + + _ => unreachable!(), + }; + + Ok(out) + } + + /// Returns the remaining capacity in the buffer. + pub fn cap(&self) -> usize { + self.buf.len() - self.idx + } + + pub fn index(&self) -> usize { + self.idx + } + + pub fn remaining(&self) -> &[u8] { + &self.buf[self.idx..] + } + + pub fn slice(&mut self, length: usize) -> Result<&[u8]> { + if self.cap() < length { + Err(BufferTooShort.into()) + } else { + let curr = self.idx; + self.idx += length; + Ok(&self.buf[curr..self.idx]) + } + } +} + +macro_rules! write_bytes { + ($this: expr, $value: expr, $len: expr) => {{ + // buf长度不够问题返回err,再由外层处理 + if $this.remaining() < $len { + return Err(BufferTooShort.into()); + } + + let mut bytes: [u8; $len] = $value.to_be_bytes(); + match $len { + 1 => {} + 2 => bytes[0] |= 0x40, + 4 => bytes[0] |= 0x80, + 8 => bytes[0] |= 0xC0, + _ => unreachable!(), + } + $this.bytes[$this.idx..$this.idx + $len].copy_from_slice(bytes.as_slice()); + $this.idx_add($len); + Ok($len) + }}; +} + +pub struct WritableBytes<'a> { + bytes: &'a mut [u8], + idx: usize, +} + +impl<'a> WritableBytes<'a> { + pub fn from(bytes: &'a mut [u8]) -> WritableBytes<'a> { + Self { bytes, idx: 0 } + } + + pub fn remaining(&self) -> usize { + self.bytes.len() - self.idx + } + + pub fn index(&self) -> usize { + self.idx + } + + pub fn idx_add(&mut self, offset: usize) { + self.idx += offset + } + + /// Writes an unsigned 8-bit integer at the current offset and advances + /// the buffer. + pub fn write_u8(&mut self, value: u8) -> Result { + write_bytes!(self, value, 1) + } + + pub fn write_u16(&mut self, value: u16) -> Result { + write_bytes!(self, value, 2) + } + + pub fn write_u32(&mut self, value: u32) -> Result { + write_bytes!(self, value, 4) + } + + pub fn write_u64(&mut self, value: u64) -> Result { + write_bytes!(self, value, 8) + } + + /// Writes an unsigned variable-length integer in network byte-order at the + /// current offset and advances the buffer. + pub fn write_varint(&mut self, value: u64) -> Result { + self.write_varint_with_len(value, varint_len(value)) + } + + fn write_varint_with_len(&mut self, value: u64, len: usize) -> Result { + if self.remaining() < len { + return Err(BufferTooShort.into()); + } + match len { + 1 => self.write_u8(value as u8), + 2 => self.write_u16(value as u16), + 4 => self.write_u32(value as u32), + 8 => self.write_u64(value), + _ => panic!("value is too large for varint"), + } + } +} + +/// Returns how many bytes it would take to encode `v` as a variable-length +/// integer. +pub const fn varint_len(v: u64) -> usize { + match v { + 0..=63 => 1, + 64..=16383 => 2, + 16384..=1_073_741_823 => 4, + 1_073_741_824..=4_611_686_018_427_387_903 => 8, + _ => { + unreachable!() + } + } +} + +/// Returns how long the variable-length integer is, given its first byte. +pub const fn parse_varint_len(byte: u8) -> usize { + let pre = byte >> 6; + if pre <= 3 { + 1 << pre + } else { + unreachable!() + } +} diff --git a/ylong_http/src/h3/parts.rs b/ylong_http/src/h3/parts.rs index 9e2bc4c..3aeb9f7 100644 --- a/ylong_http/src/h3/parts.rs +++ b/ylong_http/src/h3/parts.rs @@ -11,14 +11,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] - -use crate::h3::qpack::table::Field; use crate::h3::pseudo::PseudoHeaders; +use crate::h3::qpack::table::NameField; use crate::headers::Headers; -/// HTTP2 HEADERS frame payload implementation. -#[derive(PartialEq, Eq, Clone)] +/// HTTP3 HEADERS frame payload implementation. +#[derive(PartialEq, Eq, Clone, Debug)] pub struct Parts { pub(crate) pseudo: PseudoHeaders, pub(crate) map: Headers, @@ -43,26 +41,30 @@ impl Parts { self.map = headers; } - pub(crate) fn is_empty(&self) -> bool { + /// Whether the Headers part is empty. + pub fn is_empty(&self) -> bool { self.pseudo.is_empty() && self.map.is_empty() } - pub(crate) fn update(&mut self, headers: Field, value: String) { + /// Updates a field in the Headers part. + pub fn update(&mut self, headers: NameField, value: String) { match headers { - Field::Authority => self.pseudo.set_authority(Some(value)), - Field::Method => self.pseudo.set_method(Some(value)), - Field::Path => self.pseudo.set_path(Some(value)), - Field::Scheme => self.pseudo.set_scheme(Some(value)), - Field::Status => self.pseudo.set_status(Some(value)), - Field::Other(header) => self.map.append(header.as_str(), value.as_str()).unwrap(), + NameField::Authority => self.pseudo.set_authority(Some(value)), + NameField::Method => self.pseudo.set_method(Some(value)), + NameField::Path => self.pseudo.set_path(Some(value)), + NameField::Scheme => self.pseudo.set_scheme(Some(value)), + NameField::Status => self.pseudo.set_status(Some(value)), + NameField::Other(header) => self.map.append(header.as_str(), value.as_str()).unwrap(), } } - pub(crate) fn parts(&self) -> (&PseudoHeaders, &Headers) { + /// Gets Headers part. + pub fn parts(&self) -> (&PseudoHeaders, &Headers) { (&self.pseudo, &self.map) } - pub(crate) fn into_parts(self) -> (PseudoHeaders, Headers) { + /// Takes ownership of parts and separate Headers and pseudo. + pub fn into_parts(self) -> (PseudoHeaders, Headers) { (self.pseudo, self.map) } } diff --git a/ylong_http/src/h3/pseudo.rs b/ylong_http/src/h3/pseudo.rs index ae7ba56..c38934f 100644 --- a/ylong_http/src/h3/pseudo.rs +++ b/ylong_http/src/h3/pseudo.rs @@ -18,7 +18,7 @@ /// # Note /// The current structure is not responsible for checking every value. // TODO: 考虑将 PseudoHeaders 拆分成 `RequestPseudo` 和 `ResponsePseudo`. -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Debug)] pub struct PseudoHeaders { authority: Option, method: Option, diff --git a/ylong_http/src/h3/qpack/decoder.rs b/ylong_http/src/h3/qpack/decoder.rs index aef9f48..58aba3b 100644 --- a/ylong_http/src/h3/qpack/decoder.rs +++ b/ylong_http/src/h3/qpack/decoder.rs @@ -11,126 +11,111 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] +use std::collections::{HashMap, HashSet}; +use std::mem::take; + +use ylong_runtime::iter::parallel::ParSplit; use crate::h3::parts::Parts; +use crate::h3::qpack::decoder::FieldDecodeState::{Blocked, Decoded}; use crate::h3::qpack::error::ErrorCode::{DecompressionFailed, EncoderStreamError}; -use crate::h3::qpack::error::H3errorQpack; +use crate::h3::qpack::error::{ErrorCode, NotClassified, QpackError}; +use crate::h3::qpack::format::decoder::{ + EncInstDecoder, InstDecodeState, Name, ReprDecodeState, ReprDecoder, +}; +use crate::h3::qpack::integer::Integer; +use crate::h3::qpack::table::NameField::Path; +use crate::h3::qpack::table::{DynamicTable, NameField, TableSearcher}; use crate::h3::qpack::{ DeltaBase, EncoderInstPrefixBit, EncoderInstruction, MidBit, ReprPrefixBit, Representation, RequireInsertCount, }; -use std::mem::take; -use crate::h3::qpack::format::decoder::{ - EncInstDecoder, InstDecodeState, Name, ReprDecodeState, ReprDecoder, -}; -use crate::h3::qpack::integer::Integer; -use crate::h3::qpack::table::Field::Path; -use crate::h3::qpack::table::{DynamicTable, Field, TableSearcher}; - -/// An decoder is used to de-compress field in a compression format for efficiently representing -/// HTTP fields that is to be used in HTTP/3. This is a variation of HPACK compression that seeks -/// to reduce head-of-line blocking. -/// -/// # Examples(not run) -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// -// //decode instruction -// // convert hex string to dec-array -// let mut inst = decode("3fbd01c00f7777772e6578616d706c652e636f6dc10c2f73616d706c652f70617468").unwrap().as_slice().to_vec(); -// decoder.decode_ins(&mut inst); -// -// //decode field section -// // convert hex string to dec-array -// let mut repr = decode("03811011").unwrap().as_slice().to_vec(); -// decoder.decode_repr(&mut repr); -// -// ``` +pub(crate) enum FieldDecodeState { + Blocked, + Decoded, +} pub(crate) struct FiledLines { parts: Parts, header_size: usize, } -pub struct QpackDecoder<'a> { - field_list_size: usize, - // max header list size - table: &'a mut DynamicTable, +pub(crate) struct ReprMessage { + require_insert_count: usize, + base: usize, // dynamic table repr_state: Option, - // field decode state - inst_state: Option, + remaining: Option>, // instruction decode state lines: FiledLines, - // field lines, which is used to store the decoded field lines - base: usize, - // RFC required, got from field section prefix - require_insert_count: usize, // RFC required, got from field section prefix } -impl<'a> QpackDecoder<'a> { - - /// create a new decoder - /// # Examples(not run) - /// -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// ``` - pub fn new(field_list_size: usize, table: &'a mut DynamicTable) -> Self { +impl ReprMessage { + pub(crate) fn new() -> Self { Self { - field_list_size, - table, + require_insert_count: 0, + base: 0, repr_state: None, - inst_state: None, + remaining: None, lines: FiledLines { parts: Parts::new(), header_size: 0, }, - base: 0, - require_insert_count: 0, } } +} - /// Users can call `decode_ins` multiple times to decode decoder instructions. - /// # Examples(not run) -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// //decode instruction -// // convert hex string to dec-array -// let mut inst = decode("3fbd01c00f7777772e6578616d706c652e636f6dc10c2f73616d706c652f70617468").unwrap().as_slice().to_vec(); -// decoder.decode_ins(&mut inst); -// ``` - pub fn decode_ins(&mut self, buf: &[u8]) -> Result<(), H3errorQpack> { +pub struct QpackDecoder { + // max header list size + table: DynamicTable, + // field decode state + inst_state: Option, + streams: HashMap, + blocked: HashMap, + max_blocked_streams: usize, + max_table_capacity: usize, + max_field_section_size: usize, +} + +impl QpackDecoder { + pub(crate) fn new(max_blocked_streams: usize, max_table_capacity: usize) -> Self { + Self { + table: DynamicTable::with_empty(), + inst_state: None, + streams: HashMap::new(), + blocked: HashMap::new(), + max_blocked_streams, + max_table_capacity, + max_field_section_size: (1 << 62) - 1, + } + } + + pub(crate) fn finish_stream(&mut self, id: u64) -> Result<(), QpackError> { + if self.blocked.contains_key(&id) { + Err(QpackError::ConnectionError(ErrorCode::DecoderStreamError)) + } else { + self.streams.remove(&id); + Ok(()) + } + } + + pub(crate) fn set_max_field_section_size(&mut self, size: usize) { + self.max_field_section_size = size; + } + + pub(crate) fn decode_ins(&mut self, buf: &[u8]) -> Result, QpackError> { let mut decoder = EncInstDecoder::new(); - let mut updater = Updater::new(self.table); + let mut updater = Updater::new(&mut self.table); let mut cnt = 0; - loop { + while cnt < buf.len() { match decoder.decode(&buf[cnt..], &mut self.inst_state)? { Some(inst) => match inst { (offset, EncoderInstruction::SetCap { capacity }) => { - println!("set cap"); cnt += offset; + if capacity > self.max_table_capacity { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } updater.update_capacity(capacity)?; } ( @@ -160,202 +145,226 @@ impl<'a> QpackDecoder<'a> { updater.duplicate(index)?; } }, - None => return Result::Ok(()), + None => break, } } + + let insert_count = self.table.insert_count(); + + let unblocked = self + .blocked + .iter() + .filter_map(|(id, required)| { + if *required <= insert_count { + Some(*id) + } else { + None + } + }) + .collect::>(); + + self.blocked.retain(|_, required| *required > insert_count); + + Ok(unblocked) } - /// User call `decoder_repr` once for decoding a complete field section, which start with the `field section prefix`: - /// 0 1 2 3 4 5 6 7 - /// +---+---+---+---+---+---+---+---+ + /// User call `decoder_repr` once for decoding a complete field section, + /// which start with the `field section prefix`: 0 1 2 3 4 5 + /// 6 7 +---+---+---+---+---+---+---+---+ /// | Required Insert Count (8+) | /// +---+---------------------------+ /// | S | Delta Base (7+) | /// +---+---------------------------+ /// | Encoded Field Lines ... /// +-------------------------------+ - /// # Examples(not run) -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// //decode field section -// // convert hex string to dec-array -// let mut repr = decode("03811011").unwrap().as_slice().to_vec(); -// decoder.decode_repr(&mut repr); -// ``` - pub fn decode_repr(&mut self, buf: &[u8]) -> Result<(), H3errorQpack> { + + pub(crate) fn decode_repr( + &mut self, + buf: &[u8], + stream_id: u64, + ) -> Result { + if self.blocked.contains_key(&stream_id) { + return Err(QpackError::InternalError(NotClassified::StreamBlocked)); + } + let mut message = match self.streams.remove(&stream_id) { + None => ReprMessage::new(), + Some(mut message) => { + if let Some(vec) = message.remaining.take() { + // A block cannot occur here because this is the stream expelled from the block. + self.decode_buffered_repr(vec.as_slice(), &mut message, stream_id)?; + } + message + } + }; + + self.decode_buffered_repr(buf, &mut message, stream_id) + .map(|state| { + self.streams.insert(stream_id, message); + state + }) + } + + fn decode_buffered_repr( + &mut self, + buf: &[u8], + message: &mut ReprMessage, + stream_id: u64, + ) -> Result { + if buf.is_empty() { + return Ok(Decoded); + } let mut decoder = ReprDecoder::new(); - let mut searcher = Searcher::new(self.field_list_size, self.table, &mut self.lines); + let mut searcher = + Searcher::new(self.max_field_section_size, &self.table, &mut message.lines); let mut cnt = 0; loop { - match decoder.decode(&buf[cnt..], &mut self.repr_state)? { - Some(( - offset, + match decoder.decode(&buf[cnt..], &mut message.repr_state)? { + Some((offset, repr)) => match repr { Representation::FieldSectionPrefix { require_insert_count, signal, delta_base, - }, - )) => { - cnt += offset; - if require_insert_count.0 == 0 { - self.require_insert_count = 0; - } else { - let max_entries = self.table.max_entries(); - let full_range = 2 * max_entries; - let max_value = self.table.insert_count + max_entries; - let max_wrapped = (max_value / full_range) * full_range; - self.require_insert_count = max_wrapped + require_insert_count.0 - 1; - if self.require_insert_count > max_value { - self.require_insert_count -= full_range; + } => { + cnt += offset; + if require_insert_count.0 == 0 { + message.require_insert_count = 0; + } else { + let max_entries = searcher.table.max_entries(); + let full_range = 2 * max_entries; + if require_insert_count.0 > full_range { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } + let max_value = searcher.table.insert_count() + max_entries; + let max_wrapped = (max_value / full_range) * full_range; + message.require_insert_count = max_wrapped + require_insert_count.0 - 1; + + if message.require_insert_count > max_value { + if message.require_insert_count <= full_range { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } + message.require_insert_count -= full_range; + } + if message.require_insert_count == 0 { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } + } + if signal { + message.base = message.require_insert_count - delta_base.0 - 1; + } else { + message.base = message.require_insert_count + delta_base.0; + } + searcher.base = message.base; + if message.require_insert_count > searcher.table.insert_count() { + if self.blocked.len() > self.max_blocked_streams { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } + self.blocked.insert(stream_id, message.require_insert_count); + message.remaining = Some(Vec::from(&buf[cnt..])); + return Ok(Blocked); } } - if signal { - self.base = self.require_insert_count - delta_base.0 - 1; - } else { - self.base = self.require_insert_count + delta_base.0; + Representation::Indexed { mid_bit, index } => { + cnt += offset; + searcher.search(Representation::Indexed { mid_bit, index })?; } - searcher.base = self.base; - if self.require_insert_count > self.table.insert_count { - //todo:block + Representation::IndexedWithPostIndex { index } => { + cnt += offset; + searcher.search(Representation::IndexedWithPostIndex { index })?; } - } - Some((offset, Representation::Indexed { mid_bit, index })) => { - cnt += offset; - searcher.search(Representation::Indexed { mid_bit, index })?; - } - Some((offset, Representation::IndexedWithPostIndex { index })) => { - cnt += offset; - searcher.search(Representation::IndexedWithPostIndex { index })?; - } - Some(( - offset, Representation::LiteralWithIndexing { mid_bit, name, value, - }, - )) => { - println!("offset:{}", offset); - cnt += offset; - searcher.search_literal_with_indexing(mid_bit, name, value)?; - } - Some(( - offset, + } => { + cnt += offset; + searcher.search_literal_with_indexing(mid_bit, name, value)?; + } + Representation::LiteralWithPostIndexing { mid_bit, name, value, - }, - )) => { - cnt += offset; - searcher.search_literal_with_post_indexing(mid_bit, name, value)?; - } - Some(( - offset, + } => { + cnt += offset; + searcher.search_literal_with_post_indexing(mid_bit, name, value)?; + } Representation::LiteralWithLiteralName { mid_bit, name, value, - }, - )) => { - cnt += offset; - searcher.search_listeral_with_literal(mid_bit, name, value)?; - } - + } => { + cnt += offset; + searcher.search_listeral_with_literal(mid_bit, name, value)?; + } + }, None => { - return Result::Ok(()); + return Ok(Decoded); } } } } - /// Users call `finish` to stop decoding a field section. And send an `Section Acknowledgment` to encoder: - /// After processing an encoded field section whose declared Required Insert Count is not zero, - /// the decoder emits a Section Acknowledgment instruction. The instruction starts with the - /// '1' 1-bit pattern, followed by the field section's associated stream ID encoded as - /// a 7-bit prefix integer + /// Users call `finish` to stop decoding a field section. And send an + /// `Section Acknowledgment` to encoder: After processing an encoded + /// field section whose declared Required Insert Count is not zero, + /// the decoder emits a Section Acknowledgment instruction. The instruction + /// starts with the '1' 1-bit pattern, followed by the field section's + /// associated stream ID encoded as a 7-bit prefix integer /// 0 1 2 3 4 5 6 7 /// +---+---+---+---+---+---+---+---+ /// | 1 | Stream ID (7+) | /// +---+---------------------------+ /// # Examples(not run) -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// //decode field section -// // convert hex string to dec-array -// let mut repr = decode("03811011").unwrap().as_slice().to_vec(); -// decoder.decode_repr(&mut repr); -// //finish -// let mut qpack_decoder_buf = [0u8;20]; -// decoder.finish(1,&mut qpack_decoder_buf); -// ``` pub fn finish( &mut self, - stream_id: usize, - buf: &mut [u8], - ) -> Result<(Parts, Option), H3errorQpack> { - if self.repr_state.is_some() { - return Err(H3errorQpack::ConnectionError(DecompressionFailed)); - } - self.lines.header_size = 0; - if self.require_insert_count > 0 { - let ack = Integer::index(0x80, stream_id, 0x7f); - let size = ack.encode(buf); - if let Ok(size) = size { - return Ok((take(&mut self.lines.parts), Some(size))); + stream_id: u64, + buf: &mut Vec, + ) -> Result<(Parts, Option), QpackError> { + match self.streams.remove(&stream_id) { + None => Err(QpackError::ConnectionError(DecompressionFailed)), + Some(mut message) => { + if message.repr_state.is_some() { + return Err(QpackError::ConnectionError(DecompressionFailed)); + } + message.lines.header_size = 0; + if message.require_insert_count > 0 { + let ack = Integer::index(0x80, stream_id as usize, 0x7f); + + let mut res = Vec::new(); + ack.encode(&mut res); + buf.extend_from_slice(res.as_slice()); + return Ok((take(&mut message.lines.parts), Some(res.len()))); + } + Ok((take(&mut message.lines.parts), None)) } } - Ok((take(&mut self.lines.parts), None)) } - /// Users call `stream_cancel` to stop cancel a stream. And send an `Stream Cancellation` to encoder: - /// When a stream is reset or reading is abandoned, the decoder emits a Stream Cancellation - /// instruction. The instruction starts with the '01' 2-bit pattern, - /// followed by the stream ID of the affected stream encoded as a 6-bit prefix integer. + /// Users call `stream_cancel` to stop cancel a stream. And send an `Stream + /// Cancellation` to encoder: When a stream is reset or reading is + /// abandoned, the decoder emits a Stream Cancellation instruction. The + /// instruction starts with the '01' 2-bit pattern, followed by the + /// stream ID of the affected stream encoded as a 6-bit prefix integer. /// 0 1 2 3 4 5 6 7 /// +---+---+---+---+---+---+---+---+ /// | 0 | 1 | Stream ID (6+) | /// +---+---+-----------------------+ - /// # Examples(not run) -// ```no_run -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::h3::qpack::decoder::QpackDecoder; -// use crate::ylong_http::test_util::decode; -// const MAX_HEADER_LIST_SIZE: usize = 16 << 20; -// // Required content: -// let mut dynamic_table = DynamicTable::with_empty(); -// let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); -// //decode field section -// // convert hex string to dec-array -// let mut repr = decode("03811011").unwrap().as_slice().to_vec(); -// decoder.decode_repr(&mut repr); -// //stream_cancel -// let mut qpack_decoder_buf = [0u8;20]; -// decoder.stream_cancel(1,&mut qpack_decoder_buf); -// ``` - pub fn stream_cancel( - &mut self, - stream_id: usize, - buf: &mut [u8], - ) -> Result { - let ack = Integer::index(0x40, stream_id, 0x3f); - let size = ack.encode(buf); - if let Ok(size) = size { - return Ok(size); + pub fn stream_cancel(&mut self, stream_id: u64, buf: &mut [u8]) -> Result { + if self.table.capacity() > 0 { + self.blocked.remove(&stream_id); + self.streams.remove(&stream_id); + let ack = Integer::index(0x40, stream_id as usize, 0x3f); + let mut res = Vec::new(); + ack.encode(&mut res); + if res.len() > buf.len() { + Err(QpackError::ConnectionError(DecompressionFailed)) + } else { + buf[..res.len()].copy_from_slice(res.as_slice()); + Ok(res.len()) + } + } else { + Ok(0) } - Err(H3errorQpack::ConnectionError(DecompressionFailed)) } } @@ -368,7 +377,7 @@ impl<'a> Updater<'a> { Self { table } } - fn update_capacity(&mut self, capacity: usize) -> Result<(), H3errorQpack> { + fn update_capacity(&mut self, capacity: usize) -> Result<(), QpackError> { self.table.update_size(capacity); Ok(()) } @@ -378,18 +387,18 @@ impl<'a> Updater<'a> { mid_bit: MidBit, name: Name, value: Vec, - ) -> Result<(), H3errorQpack> { + ) -> Result<(), QpackError> { let (f, v) = - self.get_field_by_name_and_value(mid_bit, name, value, self.table.insert_count)?; + self.get_field_by_name_and_value(mid_bit, name, value, self.table.insert_count())?; self.table.update(f, v); Ok(()) } - fn duplicate(&mut self, index: usize) -> Result<(), H3errorQpack> { + fn duplicate(&mut self, index: usize) -> Result<(), QpackError> { let table_searcher = TableSearcher::new(self.table); let (f, v) = table_searcher - .find_field_dynamic(self.table.insert_count - index - 1) - .ok_or(H3errorQpack::ConnectionError(EncoderStreamError))?; + .find_field_dynamic(self.table.insert_count() - index - 1) + .ok_or(QpackError::ConnectionError(EncoderStreamError))?; self.table.update(f, v); Ok(()) } @@ -400,49 +409,53 @@ impl<'a> Updater<'a> { name: Name, value: Vec, insert_count: usize, - ) -> Result<(Field, String), H3errorQpack> { + ) -> Result<(NameField, String), QpackError> { let h = match name { Name::Index(index) => { let searcher = TableSearcher::new(self.table); if let Some(true) = mid_bit.t { searcher .find_field_name_static(index) - .ok_or(H3errorQpack::ConnectionError(EncoderStreamError))? + .ok_or(QpackError::ConnectionError(EncoderStreamError))? } else { searcher .find_field_name_dynamic(insert_count - index - 1) - .ok_or(H3errorQpack::ConnectionError(EncoderStreamError))? + .ok_or(QpackError::ConnectionError(EncoderStreamError))? } } - Name::Literal(octets) => Field::Other( + Name::Literal(octets) => NameField::Other( String::from_utf8(octets) - .map_err(|_| H3errorQpack::ConnectionError(EncoderStreamError))?, + .map_err(|_| QpackError::ConnectionError(EncoderStreamError))?, ), }; let v = String::from_utf8(value) - .map_err(|_| H3errorQpack::ConnectionError(EncoderStreamError))?; + .map_err(|_| QpackError::ConnectionError(EncoderStreamError))?; Ok((h, v)) } } struct Searcher<'a> { - field_list_size: usize, + max_field_section_size: usize, table: &'a DynamicTable, lines: &'a mut FiledLines, base: usize, } impl<'a> Searcher<'a> { - fn new(field_list_size: usize, table: &'a DynamicTable, lines: &'a mut FiledLines) -> Self { + fn new( + max_field_section_size: usize, + table: &'a DynamicTable, + lines: &'a mut FiledLines, + ) -> Self { Self { - field_list_size, + max_field_section_size, table, lines, base: 0, } } - fn search(&mut self, repr: Representation) -> Result<(), H3errorQpack> { + fn search(&mut self, repr: Representation) -> Result<(), QpackError> { match repr { Representation::Indexed { mid_bit, index } => self.search_indexed(mid_bit, index), Representation::IndexedWithPostIndex { index } => self.search_post_indexed(index), @@ -450,30 +463,30 @@ impl<'a> Searcher<'a> { } } - fn search_indexed(&mut self, mid_bit: MidBit, index: usize) -> Result<(), H3errorQpack> { + fn search_indexed(&mut self, mid_bit: MidBit, index: usize) -> Result<(), QpackError> { let table_searcher = TableSearcher::new(self.table); if let Some(true) = mid_bit.t { let (f, v) = table_searcher .find_field_static(index) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))?; + .ok_or(QpackError::ConnectionError(DecompressionFailed))?; self.lines.parts.update(f, v); Ok(()) } else { let (f, v) = table_searcher .find_field_dynamic(self.base - index - 1) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))?; + .ok_or(QpackError::ConnectionError(DecompressionFailed))?; self.lines.parts.update(f, v); Ok(()) } } - fn search_post_indexed(&mut self, index: usize) -> Result<(), H3errorQpack> { + fn search_post_indexed(&mut self, index: usize) -> Result<(), QpackError> { let table_searcher = TableSearcher::new(self.table); let (f, v) = table_searcher .find_field_dynamic(self.base + index) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))?; + .ok_or(QpackError::ConnectionError(DecompressionFailed))?; self.check_field_list_size(&f, &v)?; self.lines.parts.update(f, v); Ok(()) @@ -484,7 +497,7 @@ impl<'a> Searcher<'a> { mid_bit: MidBit, name: Name, value: Vec, - ) -> Result<(), H3errorQpack> { + ) -> Result<(), QpackError> { let (f, v) = self.get_field_by_name_and_value( mid_bit, name, @@ -501,7 +514,7 @@ impl<'a> Searcher<'a> { mid_bit: MidBit, name: Name, value: Vec, - ) -> Result<(), H3errorQpack> { + ) -> Result<(), QpackError> { let (f, v) = self.get_field_by_name_and_value( mid_bit, name, @@ -518,7 +531,7 @@ impl<'a> Searcher<'a> { mid_bit: MidBit, name: Name, value: Vec, - ) -> Result<(), H3errorQpack> { + ) -> Result<(), QpackError> { let (h, v) = self.get_field_by_name_and_value( mid_bit, name, @@ -536,7 +549,7 @@ impl<'a> Searcher<'a> { name: Name, value: Vec, repr: ReprPrefixBit, - ) -> Result<(Field, String), H3errorQpack> { + ) -> Result<(NameField, String), QpackError> { let h = match name { Name::Index(index) => { if repr == ReprPrefixBit::LITERALWITHINDEXING { @@ -544,36 +557,37 @@ impl<'a> Searcher<'a> { if let Some(true) = mid_bit.t { searcher .find_field_name_static(index) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))? + .ok_or(QpackError::ConnectionError(DecompressionFailed))? } else { searcher .find_field_name_dynamic(self.base - index - 1) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))? + .ok_or(QpackError::ConnectionError(DecompressionFailed))? } } else { let searcher = TableSearcher::new(self.table); searcher .find_field_name_dynamic(self.base + index) - .ok_or(H3errorQpack::ConnectionError(DecompressionFailed))? + .ok_or(QpackError::ConnectionError(DecompressionFailed))? } } - Name::Literal(octets) => Field::Other( + Name::Literal(octets) => NameField::Other( String::from_utf8(octets) - .map_err(|_| H3errorQpack::ConnectionError(DecompressionFailed))?, + .map_err(|_| QpackError::ConnectionError(DecompressionFailed))?, ), }; let v = String::from_utf8(value) - .map_err(|_| H3errorQpack::ConnectionError(DecompressionFailed))?; + .map_err(|_| QpackError::ConnectionError(DecompressionFailed))?; Ok((h, v)) } pub(crate) fn update_size(&mut self, addition: usize) { self.lines.header_size += addition; } - fn check_field_list_size(&mut self, key: &Field, value: &String) -> Result<(), H3errorQpack> { + + fn check_field_list_size(&mut self, key: &NameField, value: &str) -> Result<(), QpackError> { let line_size = field_line_length(key.len(), value.len()); self.update_size(line_size); - if self.lines.header_size > self.field_list_size { - Err(H3errorQpack::ConnectionError(DecompressionFailed)) + if self.lines.header_size > self.max_field_section_size { + Err(QpackError::ConnectionError(DecompressionFailed)) } else { Ok(()) } @@ -583,391 +597,3 @@ impl<'a> Searcher<'a> { fn field_line_length(key_size: usize, value_size: usize) -> usize { key_size + value_size + 32 } - -#[cfg(test)] -mod ut_qpack_decoder { - use crate::h3::qpack::format::decoder::ReprDecodeState; - use crate::h3::qpack::table::{DynamicTable, Field}; - use crate::h3::qpack::QpackDecoder; - use crate::util::test_util::decode; - - const MAX_HEADER_LIST_SIZE: usize = 16 << 20; - - #[test] - fn ut_qpack_decoder() { - rfc9204_test_cases(); - test_need_more(); - test_indexed_static(); - test_indexed_dynamic(); - test_post_indexed_dynamic(); - test_literal_indexing_static(); - test_literal_indexing_dynamic(); - test_literal_post_indexing_dynamic(); - test_literal_with_literal_name(); - test_setcap(); - decode_long_field(); - - fn get_state(state: &Option) { - match state { - Some(ReprDecodeState::FiledSectionPrefix(_)) => { - println!("FiledSectionPrefix"); - } - Some(ReprDecodeState::ReprIndex(_)) => { - println!("Indexed"); - } - Some(ReprDecodeState::ReprValueString(_)) => { - println!("ReprValueString"); - } - Some(ReprDecodeState::ReprNameAndValue(_)) => { - println!("ReprNameAndValue"); - } - None => { - println!("None"); - } - } - } - macro_rules! check_pseudo { - ( - $pseudo: expr, - { $a: expr, $m: expr, $p: expr, $sc: expr, $st: expr } $(,)? - ) => { - assert_eq!($pseudo.authority(), $a); - assert_eq!($pseudo.method(), $m); - assert_eq!($pseudo.path(), $p); - assert_eq!($pseudo.scheme(), $sc); - assert_eq!($pseudo.status(), $st); - }; - } - - macro_rules! get_parts { - ($qpack: expr $(, $input: literal)*) => {{ - $( - let text = decode($input).unwrap().as_slice().to_vec(); - assert!($qpack.decode_repr(&text).is_ok()); - )* - let mut ack = [0u8; 20]; - match $qpack.finish(1,&mut ack) { - Ok((parts,_)) => parts, - Err(_) => panic!("QpackDecoder::finish() failed!"), - } - }}; - } - macro_rules! check_map { - ($map: expr, { $($(,)? $k: literal => $v: literal)* } $(,)?) => { - $( - assert_eq!($map.get($k).unwrap().to_string().unwrap(), $v); - )* - } - } - macro_rules! qpack_test_case { - ( - $qpack: expr $(, $input: literal)*, - { $a: expr, $m: expr, $p: expr, $sc: expr, $st: expr }, - { $size: expr $(, $($k2: literal)? $($k3: ident)? => $v2: literal)* } $(,)? - ) => { - let mut _qpack = $qpack; - let (pseudo, _) = get_parts!(_qpack $(, $input)*).into_parts(); - check_pseudo!(pseudo, { $a, $m, $p, $sc, $st }); - }; - ( - $qpack: expr $(, $input: literal)*, - { $($(,)? $k1: literal => $v1: literal)* }, - { $size: expr $(, $($k2: literal)? $($k3: ident)? => $v2: literal)* } $(,)? - ) => { - let mut _qpack = $qpack; - let (_, map) = get_parts!(_qpack $(, $input)*).into_parts(); - check_map!(map, { $($k1 => $v1)* }); - }; - ( - $hpack: expr $(, $input: literal)*, - { $a: expr, $m: expr, $p: expr, $sc: expr, $st: expr }, - { $($(,)? $k1: literal => $v1: literal)* }, - { $size: expr $(, $($k2: literal)? $($k3: ident)? => $v2: literal)* } $(,)? - ) => { - let mut _hpack = $hpack; - let (pseudo, map) = get_parts!(_hpack $(, $input)*).into_parts(); - check_pseudo!(pseudo, { $a, $m, $p, $sc, $st }); - check_map!(map, { $($k1 => $v1)* }); - }; - } - - fn rfc9204_test_cases() { - literal_field_line_with_name_reference(); - dynamic_table(); - speculative_insert(); - duplicate_instruction_stream_cancellation(); - dynamic_table_insert_eviction(); - fn literal_field_line_with_name_reference() { - println!("run literal_field_line_with_name_reference"); - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - qpack_test_case!( - decoder, - "0000510b2f696e6465782e68746d6c", - { None, None, Some("/index.html"), None, None }, - { 0 } - ); - println!("passed"); - } - fn dynamic_table() { - println!("dynamic_table"); - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let ins = - decode("3fbd01c00f7777772e6578616d706c652e636f6dc10c2f73616d706c652f70617468") - .unwrap() - .as_slice() - .to_vec(); - let _ = decoder.decode_ins(&ins); - get_state(&decoder.repr_state); - qpack_test_case!( - decoder, - "03811011", - { Some("www.example.com"), None, Some("/sample/path"), None, None }, - { 0 } - ); - } - fn speculative_insert() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let ins = decode("4a637573746f6d2d6b65790c637573746f6d2d76616c7565") - .unwrap() - .as_slice() - .to_vec(); - let _ = decoder.decode_ins(&ins); - qpack_test_case!( - decoder, - "028010", - { "custom-key"=>"custom-value" }, - { 0 } - ); - } - fn duplicate_instruction_stream_cancellation() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - dynamic_table.update(Field::Authority, String::from("www.example.com")); - dynamic_table.update(Field::Path, String::from("/sample/path")); - dynamic_table.update( - Field::Other(String::from("custom-key")), - String::from("custom-value"), - ); - dynamic_table.known_received_count = 3; - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let ins = decode("02").unwrap().as_slice().to_vec(); - let _ = decoder.decode_ins(&ins); - qpack_test_case!( - decoder, - "058010c180", - { Some("www.example.com"), None, Some("/"), None, None }, - { "custom-key"=>"custom-value" }, - { 0 } - ); - } - fn dynamic_table_insert_eviction() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - dynamic_table.update(Field::Authority, String::from("www.example.com")); - dynamic_table.update(Field::Path, String::from("/sample/path")); - dynamic_table.update( - Field::Other(String::from("custom-key")), - String::from("custom-value"), - ); - dynamic_table.update(Field::Authority, String::from("www.example.com")); - dynamic_table.known_received_count = 3; - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let ins = decode("810d637573746f6d2d76616c756532") - .unwrap() - .as_slice() - .to_vec(); - let _ = decoder.decode_ins(&ins); - qpack_test_case!( - decoder, - "068111", - { "custom-key"=>"custom-value2" }, - { 0 } - ); - } - } - - fn test_need_more() { - println!("test_need_more"); - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let text = decode("00").unwrap().as_slice().to_vec(); //510b2f696e6465782e68746d6c - println!("text={:?}", text); - let _ = decoder.decode_repr(&text); - get_state(&decoder.repr_state); - let text2 = decode("00510b2f696e6465782e68746d6c") - .unwrap() - .as_slice() - .to_vec(); - println!("text2={:?}", text2); - let _ = decoder.decode_repr(&text2); - } - - fn test_indexed_static() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - "0000d1", - { None, Some("GET"), None, None, None }, - { 0 } - ); - } - fn test_indexed_dynamic() { - // Test index "custom-field"=>"custom-value" in dynamic table - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - //abs = 0 - dynamic_table.update( - Field::Other(String::from("custom-field")), - String::from("custom-value"), - ); - //abs = 1 - dynamic_table.update( - Field::Other(String::from("my-field")), - String::from("my-value"), - ); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - //require_insert_count=2, signal=false, delta_base=0 - //so base=2 - //rel_index=1 (abs=2(base)-1-1=0) - "030081", - {"custom-field"=>"custom-value"}, - { 0 } - ); - } - fn test_post_indexed_dynamic() { - // Test index "custom-field"=>"custom-value" in dynamic table - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - //abs = 0 - dynamic_table.update( - Field::Other(String::from("custom1-field")), - String::from("custom1-value"), - ); - //abs = 1 - dynamic_table.update( - Field::Other(String::from("custom2-field")), - String::from("custom2-value"), - ); - //abs = 2 - dynamic_table.update( - Field::Other(String::from("custom3-field")), - String::from("custom3-value"), - ); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - //require_insert_count=3, signal=true, delta_base=2 - //so base = 3-2-1 = 0 - //rel_index=1 (abs=0(base)+1=1) - "048211", - {"custom2-field"=>"custom2-value"}, - { 0 } - ); - } - fn test_literal_indexing_static() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - "00007f020d637573746f6d312d76616c7565", - { None, Some("custom1-value"), None, None, None }, - { 0 } - ); - } - - fn test_literal_indexing_dynamic() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - //abs = 0 - dynamic_table.update( - Field::Other(String::from("custom-field")), - String::from("custom-value"), - ); - //abs = 1 - dynamic_table.update( - Field::Other(String::from("my-field")), - String::from("my-value"), - ); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - //require_insert_count=2, signal=false, delta_base=0 - //so base=2 - //rel_index=1 (abs=2(base)-1-1=0) - "0300610d637573746f6d312d76616c7565", - {"custom-field"=>"custom1-value"}, - { 0 } - ); - } - - fn test_literal_post_indexing_dynamic() { - // Test index "custom-field"=>"custom-value" in dynamic table - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - //abs = 0 - dynamic_table.update( - Field::Other(String::from("custom1-field")), - String::from("custom1-value"), - ); - //abs = 1 - dynamic_table.update( - Field::Other(String::from("custom2-field")), - String::from("custom2-value"), - ); - //abs = 2 - dynamic_table.update( - Field::Other(String::from("custom3-field")), - String::from("custom3-value"), - ); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE,&mut dynamic_table), - //require_insert_count=3, signal=true, delta_base=2 - //so base = 3-2-1 = 0 - //rel_index=1 (abs=0(base)+1=1) - "0482010d637573746f6d312d76616c7565", - {"custom2-field"=>"custom1-value"}, - { 0 } - ); - } - - fn test_literal_with_literal_name() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - qpack_test_case!( - QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table), - "00003706637573746f6d322d76616c75650d637573746f6d312d76616c7565", - {"custom2-value"=>"custom1-value"}, - {0}, - ); - } - - fn test_setcap() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let ins = decode("3fbd01").unwrap().as_slice().to_vec(); - let _ = decoder.decode_ins(&ins); - assert_eq!(decoder.table.capacity(), 220); - } - - fn decode_long_field() { - let mut dynamic_table = DynamicTable::with_empty(); - dynamic_table.update_size(4096); - let mut decoder = QpackDecoder::new(MAX_HEADER_LIST_SIZE, &mut dynamic_table); - let repr = decode("ffffff01ffff01037fffff01") - .unwrap() - .as_slice() - .to_vec(); - let _ = decoder.decode_repr(&repr); - assert_eq!(decoder.base, 32382); - } - } -} diff --git a/ylong_http/src/h3/qpack/encoder.rs b/ylong_http/src/h3/qpack/encoder.rs index f1e2a23..e42c74d 100644 --- a/ylong_http/src/h3/qpack/encoder.rs +++ b/ylong_http/src/h3/qpack/encoder.rs @@ -11,580 +11,251 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] +use std::collections::{HashMap, HashSet, VecDeque}; use crate::h3::parts::Parts; use crate::h3::qpack::error::ErrorCode::DecoderStreamError; -use crate::h3::qpack::error::H3errorQpack; +use crate::h3::qpack::error::{ErrorCode, QpackError}; use crate::h3::qpack::format::encoder::{ DecInstDecoder, InstDecodeState, PartsIter, ReprEncodeState, SetCap, }; use crate::h3::qpack::format::ReprEncoder; use crate::h3::qpack::integer::{Integer, IntegerEncoder}; -use crate::h3::qpack::table::{DynamicTable, Field}; +use crate::h3::qpack::table::{DynamicTable, NameField}; use crate::h3::qpack::{DecoderInstruction, PrefixMask}; -use std::collections::{HashMap, VecDeque}; -/// An encoder is used to compress field in a compression format for efficiently representing -/// HTTP fields that is to be used in HTTP/3. This is a variation of HPACK compression that seeks -/// to reduce head-of-line blocking. -/// -/// # Examples -// ```no_run -// use crate::ylong_http::h3::qpack::encoder::QpackEncoder; -// use crate::ylong_http::h3::parts::Parts; -// use crate::ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use crate::ylong_http::test_util::decode; -// -// -// -// // the (field, value) is: ("custom-key", "custom-value2") -// // Required content: -// let mut encoder_buf = [0u8; 1024]; // QPACK stream providing control commands. -// let mut stream_buf = [0u8; 1024]; // Field section encoded in QPACK format. -// let mut encoder_cur = 0; // index of encoder_buf. -// let mut stream_cur = 0; // index of stream_buf. -// let mut table = DynamicTable::with_empty(); -// -// // create a new encoder. -// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); -// -// // set dynamic table capacity. -// encoder_cur += encoder.set_capacity(220, &mut encoder_buf[encoder_buf..]); -// -// // set field section. -// let mut field = Parts::new(); -// field.update(Field::Other(String::from("custom-key")), String::from("custom-value")); -// encoder.set_parts(field); -// -// // encode field section. -// let (cur1, cur2, _) = encoder.encode(&mut encoder_buf[encoder_cur..], &mut stream_buf[stream_cur..]); -// encoder_cur += cur1; -// stream_cur += cur2; -// -// assert_eq!(stream_buf[..encoder_cur].to_vec().as_slice(), decode("028010").unwrap().as_slice()); -// assert_eq!(stream_buf[..stream_cur].to_vec().as_slice(), decode("4a637573746f6d2d6b65790c637573746f6d2d76616c7565").unwrap().as_slice()); -// -// -// ``` - -pub struct QpackEncoder<'a> { - table: &'a mut DynamicTable, +pub struct QpackEncoder { + max_blocked_streams: usize, + blocked_stream_nums: usize, + capacity_to_update: Option, + tracked_stream: HashMap, + table: DynamicTable, + is_huffman: bool, // Headers to be encode. field_iter: Option, // save the state of encoding field. field_state: Option, // save the state of decoding instructions. inst_state: Option, - // list of fields to be inserted. - insert_list: VecDeque<(Field, String)>, insert_length: usize, - // `RFC`: the number of insertions that the decoder needs to receive before it can decode the field section. + // `RFC`: the number of insertions that the decoder needs to receive before it can decode the + // field section. required_insert_count: usize, - - stream_id: usize, - // allow reference to the inserting field default is false. - allow_post: bool, - // RFC9204-2.1.1.1. if index QpackEncoder<'a> { +#[derive(Default)] +pub(crate) struct UnackFields { + unacked_section: VecDeque>, +} - /// create a new encoder. - /// #Examples -// ```no_run -// use ylong_http::h3::qpack::encoder::QpackEncoder; -// use ylong_http::h3::qpack::table::DynamicTable; -// let mut encoder_buf = [0u8; 1024]; // QPACK stream providing control commands. -// let mut stream_buf = [0u8; 1024]; // Field section encoded in QPACK format. -// let mut encoder_cur = 0; // index of encoder_buf. -// let mut stream_cur = 0; // index of stream_buf. -// let mut table = DynamicTable::with_empty(); -// -// // create a new encoder. -// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); -// ``` - pub fn new( - table: &'a mut DynamicTable, - stream_id: usize, - allow_post: bool, - draining_index: usize, - ) -> QpackEncoder { +impl UnackFields { + pub(crate) fn new(unacked: VecDeque>) -> Self { Self { - table, - field_iter: None, - field_state: None, - inst_state: None, - insert_list: VecDeque::new(), - insert_length: 0, - required_insert_count: 0, - stream_id, - allow_post, - draining_index, + unacked_section: unacked, } } + pub(crate) fn max_unacked_index(&self) -> Option { + self.unacked_section.iter().flatten().max().cloned() + } - /// Set the maximum dynamic table size. - /// # Examples - /// ```no_run - /// use ylong_http::h3::qpack::encoder::QpackEncoder; - /// use ylong_http::h3::qpack::table::DynamicTable; - /// let mut encoder_buf = [0u8; 1024]; // QPACK stream providing control commands. - /// let mut stream_buf = [0u8; 1024]; // Field section encoded in QPACK format. - /// let mut encoder_cur = 0; // index of encoder_buf. - /// let mut stream_cur = 0; // index of stream_buf. - /// let mut table = DynamicTable::with_empty(); - /// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); - /// let mut encoder_cur = encoder.set_capacity(220, &mut encoder_buf[..]); - /// ``` - pub fn set_capacity(&mut self, max_size: usize, encoder_buf: &mut [u8]) -> usize { - self.table.update_size(max_size); - if let Ok(cur) = SetCap::new(max_size).encode(&mut encoder_buf[..]) { - return cur; - } - 0 + pub(crate) fn unacked_section_mut(&mut self) -> &mut VecDeque> { + &mut self.unacked_section + } + + pub(crate) fn update(&mut self, unacked: HashSet) { + self.unacked_section.push_back(unacked); } +} +pub struct EncodeMessage { + fields: Vec, + inst: Vec, +} - /// Set the field section to be encoded. - /// # Examples -// ```no_run -// use ylong_http::h3::qpack::encoder::QpackEncoder; -// use ylong_http::h3::parts::Parts; -// use ylong_http::h3::qpack::table::{DynamicTable, Field}; -// let mut table = DynamicTable::with_empty(); -// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); -// let mut parts = Parts::new(); -// parts.update(Field::Other(String::from("custom-key")), String::from("custom-value")); -// encoder.set_parts(parts); -// ``` - pub fn set_parts(&mut self, parts: Parts) { - self.field_iter = Some(PartsIter::new(parts)); +impl EncodeMessage { + pub fn new(fields: Vec, inst: Vec) -> Self { + Self { fields, inst } + } + pub fn fields(&self) -> &Vec { + &self.fields } - fn ack(&mut self, stream_id: usize) -> Result, H3errorQpack> { - assert_eq!(stream_id, self.stream_id); + pub fn inst(&self) -> &Vec { + &self.inst + } +} - if self.table.known_received_count < self.required_insert_count { - self.table.known_received_count = self.required_insert_count; +impl QpackEncoder { + pub(crate) fn finish_stream(&self, id: u64) -> Result<(), QpackError> { + if self.tracked_stream.contains_key(&id) { + Err(QpackError::ConnectionError(ErrorCode::EncoderStreamError)) } else { - return Err(H3errorQpack::ConnectionError(DecoderStreamError)); + Ok(()) } + } - Ok(Some(DecoderInst::Ack)) + pub(crate) fn set_max_table_capacity(&mut self, max_cap: usize) -> Result<(), QpackError> { + const MAX_TABLE_CAPACITY: usize = (1 << 30) - 1; + if max_cap > MAX_TABLE_CAPACITY { + return Err(QpackError::ConnectionError(ErrorCode::H3SettingsError)); + } + self.capacity_to_update = Some(max_cap); + Ok(()) } - /// Users can call `decode_ins` multiple times to decode decoder instructions. - /// # Return - /// `Ok(None)` means that the decoder instruction is not complete. - /// # Examples -// ```no_run -// use ylong_http::h3::qpack::encoder::QpackEncoder; -// use ylong_http::h3::parts::Parts; -// use ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use ylong_http::test_util::decode; -// let mut table = DynamicTable::with_empty(); -// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); -// let _ = encoder.decode_ins(&mut decode("80").unwrap().as_slice()); -// ``` - pub fn decode_ins(&mut self, buf: &[u8]) -> Result, H3errorQpack> { - let mut decoder = DecInstDecoder::new(buf); + pub(crate) fn set_max_blocked_stream_size(&mut self, max_blocked: usize) { + self.max_blocked_streams = max_blocked; + } - match decoder.decode(&mut self.inst_state)? { - Some(DecoderInstruction::Ack { stream_id }) => self.ack(stream_id), - //todo: stream cancel - Some(DecoderInstruction::StreamCancel { stream_id }) => { - assert_eq!(stream_id, self.stream_id); - Ok(Some(DecoderInst::StreamCancel)) - } - //todo: insert count increment - Some(DecoderInstruction::InsertCountIncrement { increment }) => { - self.table.known_received_count += increment; - Ok(Some(DecoderInst::InsertCountIncrement)) + fn update_max_dynamic_table_cap(&mut self, encoder_buf: &mut Vec) { + if let Some(new_cap) = self.capacity_to_update { + if self.table.update_capacity(new_cap).is_some() { + SetCap::new(new_cap).encode(encoder_buf); + self.capacity_to_update = None; } - None => Ok(None), } } - fn get_prefix(&self, prefix_buf: &mut [u8]) -> usize { - let mut cur_prefix = 0; - let mut wire_ric = 0; - if self.required_insert_count != 0 { - wire_ric = self.required_insert_count % (2 * self.table.max_entries()) + 1; + pub fn set_parts(&mut self, parts: Parts) { + self.field_iter = Some(PartsIter::new(parts)); + } + + fn ack(&mut self, stream_id: usize) -> Result<(), QpackError> { + let mut known_received = self.table.known_recved_count(); + if let Some(unacked) = self.tracked_stream.get_mut(&(stream_id as u64)) { + if let Some(unacked_index) = unacked.unacked_section_mut().pop_front() { + for index in unacked_index { + if (index as u64) > known_received { + known_received += 1; + } + self.table.untracked_field(index); + } + } + if unacked.unacked_section_mut().is_empty() { + self.tracked_stream.remove(&(stream_id as u64)); + } } + let increment = known_received - self.table.known_recved_count(); - cur_prefix += Integer::index(0x00, wire_ric, 0xff) - .encode(&mut prefix_buf[..]) - .unwrap_or(0); - let base = self.table.insert_count; - println!("base: {}", base); - println!("required_insert_count: {}", self.required_insert_count); - if base >= self.required_insert_count { - cur_prefix += Integer::index(0x00, base - self.required_insert_count, 0x7f) - .encode(&mut prefix_buf[cur_prefix..]) - .unwrap_or(0); - } else { - cur_prefix += Integer::index(0x80, self.required_insert_count - base - 1, 0x7f) - .encode(&mut prefix_buf[cur_prefix..]) - .unwrap_or(0); + if increment > 0 { + self.increase_insert_count(increment as usize) } - cur_prefix + Ok(()) } - /// Users can call `encode` multiple times to encode multiple complete field sections. - /// # Examples -// ```no_run -// use ylong_http::h3::qpack::encoder::QpackEncoder; -// use ylong_http::h3::parts::Parts; -// use ylong_http::h3::qpack::table::{DynamicTable, Field}; -// use ylong_http::test_util::decode; -// let mut encoder_buf = [0u8; 1024]; // QPACK stream providing control commands. -// let mut stream_buf = [0u8; 1024]; // Field section encoded in QPACK format. -// let mut encoder_cur = 0; // index of encoder_buf. -// let mut stream_cur = 0; // index of stream_buf. -// let mut table = DynamicTable::with_empty(); -// let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); -// let mut parts = Parts::new(); -// parts.update(Field::Other(String::from("custom-key")), String::from("custom-value")); -// encoder.set_parts(parts); -// let (cur1, cur2, _) = encoder.encode(&mut encoder_buf[encoder_cur..], &mut stream_buf[stream_cur..]); -// encoder_cur += cur1; -// stream_cur += cur2; -// ``` - pub fn encode( - &mut self, - encoder_buf: &mut [u8], //instructions encoded results - stream_buf: &mut [u8], //headers encoded results - ) -> (usize, usize, Option<([u8; 1024], usize)>) { - let (mut cur_encoder, mut cur_stream) = (0, 0); - if self.is_finished() { - // denote an end of field section - // self.stream_reference.push_back(None); - //todo: size of prefix_buf - let mut prefix_buf = [0u8; 1024]; - let cur_prefix = self.get_prefix(&mut prefix_buf[0..]); - for (field, value) in self.insert_list.iter() { - self.table.update(field.clone(), value.clone()); - } - (cur_encoder, cur_stream, Some((prefix_buf, cur_prefix))) - } else { - let mut encoder = ReprEncoder::new( - self.table, - self.draining_index, - self.allow_post, - &mut self.insert_length, - ); - (cur_encoder, cur_stream) = encoder.encode( - &mut self.field_iter, - &mut self.field_state, - &mut encoder_buf[0..], - &mut stream_buf[0..], - &mut self.insert_list, - &mut self.required_insert_count, - ); - (cur_encoder, cur_stream, None) - } + + fn increase_insert_count(&mut self, increment: usize) { + self.table.increase_known_receive_count(increment); + self.update_blocked_stream(); } - /// Check the previously set `Parts` if encoding is complete. - pub(crate) fn is_finished(&self) -> bool { - self.field_iter.is_none() && self.field_state.is_none() + fn cancel_stream(&mut self, stream_id: u64) { + let mut stream_blocked = false; + if let Some(mut fields) = self.tracked_stream.remove(&stream_id) { + fields + .unacked_section_mut() + .iter() + .flatten() + .for_each(|index| { + self.table.untracked_field(*index); + if *index > (self.table.known_recved_count() as usize) { + stream_blocked = true; + } + }) + } + if stream_blocked { + self.blocked_stream_nums -= 1; + } } -} -pub enum DecoderInst { - Ack, - StreamCancel, - InsertCountIncrement, -} + fn update_blocked_stream(&mut self) { + let known_receive_cnt = self.table.known_recved_count() as usize; + let mut blocked = 0; + self.tracked_stream.iter_mut().for_each(|(_, fields)| { + if fields + .unacked_section_mut() + .iter() + .flatten() + .any(|index| *index > known_receive_cnt) + { + blocked += 1; + } + }); + self.blocked_stream_nums = blocked; + } -#[cfg(test)] -mod ut_qpack_encoder { - use crate::h3::parts::Parts; - use crate::h3::qpack::encoder; - use crate::h3::qpack::encoder::QpackEncoder; - use crate::h3::qpack::table::{DynamicTable, Field}; - use crate::util::test_util::decode; - macro_rules! qpack_test_cases { - ($enc: expr,$encoder_buf:expr,$encoder_cur:expr, $len: expr, $res: literal,$encoder_res: literal, $size: expr, { $($h: expr, $v: expr $(,)?)*} $(,)?) => { - let mut _encoder = $enc; - let mut stream_buf = [0u8; $len]; - let mut stream_cur = 0; - $( - let mut parts = Parts::new(); - parts.update($h, $v); - _encoder.set_parts(parts); - let (cur1,cur2,_) = _encoder.encode(&mut $encoder_buf[$encoder_cur..],&mut stream_buf[stream_cur..]); - $encoder_cur += cur1; - stream_cur += cur2; - )* - let (cur1, cur2, prefix) = _encoder.encode(&mut $encoder_buf[$encoder_cur..],&mut stream_buf[stream_cur..]); - $encoder_cur += cur1; - stream_cur += cur2; - if let Some((prefix_buf,cur_prefix)) = prefix{ - stream_buf.copy_within(0..stream_cur,cur_prefix); - stream_buf[..cur_prefix].copy_from_slice(&prefix_buf[..cur_prefix]); - stream_cur += cur_prefix; + pub fn decode_ins(&mut self, buf: &[u8]) -> Result<(), QpackError> { + let mut decoder = DecInstDecoder::new(buf); + loop { + match decoder.decode(&mut self.inst_state)? { + Some(DecoderInstruction::Ack { stream_id }) => self.ack(stream_id)?, + Some(DecoderInstruction::StreamCancel { stream_id }) => { + self.cancel_stream(stream_id as u64); } - println!("stream_buf: {:#?}",stream_buf); - let result = decode($res).unwrap(); - if let Some(res) = decode($encoder_res){ - assert_eq!($encoder_buf[..$encoder_cur].to_vec().as_slice(), res.as_slice()); + Some(DecoderInstruction::InsertCountIncrement { increment }) => { + self.increase_insert_count(increment); } - assert_eq!(stream_cur, $len); - assert_eq!(stream_buf.as_slice(), result.as_slice()); - assert_eq!(_encoder.table.size(), $size); + None => return Ok(()), } } - #[test] - /// The encoder sends an encoded field section containing a literal representation of a field - /// with a static name reference. - fn literal_field_line_with_name_reference() { - println!("literal_field_line_with_name_reference"); - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, false, 0); - let mut encoder_cur = encoder.set_capacity(0, &mut encoder_buf[..]); - qpack_test_cases!( - encoder, - encoder_buf, - encoder_cur, - 15, "0000510b2f696e6465782e68746d6c", - "20", - 0, - { - Field::Path, - String::from("/index.html"), - }, - ); } - #[test] - ///The encoder sets the dynamic table capacity, inserts a header with a dynamic name - /// reference, then sends a potentially blocking, encoded field section referencing - /// this new entry. The decoder acknowledges processing the encoded field section, - /// which implicitly acknowledges all dynamic table insertions up to the Required - /// Insert Count. - fn dynamic_table() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 0); - let mut encoder_cur = encoder.set_capacity(220, &mut encoder_buf[..]); - qpack_test_cases!( - encoder, - encoder_buf, - encoder_cur, - 4, "03811011", - "3fbd01c00f7777772e6578616d706c652e636f6dc10c2f73616d706c652f70617468", - 106, - { - Field::Authority, - String::from("www.example.com"), - Field::Path, - String::from("/sample/path"), - }, - ); - } + pub fn encode(&mut self, stream_id: u64) -> EncodeMessage { + let mut fields = Vec::new(); + let mut inst = Vec::new(); + self.update_max_dynamic_table_cap(&mut inst); - #[test] - ///The encoder inserts a header into the dynamic table with a literal name. - /// The decoder acknowledges receipt of the entry. The encoder does not send any - /// encoded field sections. - fn speculative_insert() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 0); - let _ = encoder.set_capacity(220, &mut encoder_buf[..]); - let mut encoder_cur = 0; - qpack_test_cases!( - encoder, - encoder_buf, - encoder_cur, - 3, "028010", - "4a637573746f6d2d6b65790c637573746f6d2d76616c7565", - 54, - { - Field::Other(String::from("custom-key")), - String::from("custom-value"), - }, - ); - } + let stream_blocked = self + .tracked_stream + .get(&stream_id) + .map_or(false, |unacked| { + unacked + .max_unacked_index() + .map_or(false, |idx| (idx as u64) > self.table.known_recved_count()) + }); - #[test] - fn duplicate_instruction_stream_cancellation() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); - let _ = encoder.set_capacity(4096, &mut encoder_buf[..]); - encoder - .table - .update(Field::Authority, String::from("www.example.com")); - encoder - .table - .update(Field::Path, String::from("/sample/path")); - encoder.table.update( - Field::Other(String::from("custom-key")), - String::from("custom-value"), + let reach_max_block = self.reach_max_blocked(); + let mut encoder = ReprEncoder::new( + stream_id, + self.table.insert_count() as u64, + self.is_huffman, + &mut self.table, ); - encoder.required_insert_count = 3; - let mut encoder_cur = 0; - qpack_test_cases!( - encoder, - encoder_buf, - encoder_cur, - 5, "050080c181", - "02", - 274, - { - Field::Authority, - String::from("www.example.com"), - Field::Path, - String::from("/"), - Field::Other(String::from("custom-key")), - String::from("custom-value") - }, + encoder.iterate_encode_fields( + &mut self.field_iter, + &mut self.tracked_stream, + &mut self.blocked_stream_nums, + stream_blocked || !reach_max_block, + &mut fields, + &mut inst, ); + EncodeMessage::new(fields, inst) } - #[test] - fn dynamic_table_insert_eviction() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); - let _ = encoder.set_capacity(4096, &mut encoder_buf[..]); - encoder - .table - .update(Field::Authority, String::from("www.example.com")); - encoder - .table - .update(Field::Path, String::from("/sample/path")); - encoder.table.update( - Field::Other(String::from("custom-key")), - String::from("custom-value"), - ); - encoder - .table - .update(Field::Authority, String::from("www.example.com")); - encoder.required_insert_count = 3; //acked - let mut encoder_cur = 0; - qpack_test_cases!( - encoder, - encoder_buf, - encoder_cur, - 3, "040183", - "810d637573746f6d2d76616c756532", - 272, - { - Field::Other(String::from("custom-key")), - String::from("custom-value2") - }, - ); + pub(crate) fn reach_max_blocked(&self) -> bool { + self.blocked_stream_nums >= self.max_blocked_streams } +} - #[test] - fn test_ack() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); - let mut encoder_cur = encoder.set_capacity(4096, &mut encoder_buf[..]); - - let field_list: [(Field, String); 3] = [ - (Field::Authority, String::from("www.example.com")), - (Field::Path, String::from("/sample/path")), - ( - Field::Other(String::from("custom-key")), - String::from("custom-value"), - ), - ]; - let mut stream_cur = 0; - for (field, value) in field_list.iter() { - let mut parts = Parts::new(); - parts.update(field.clone(), value.clone()); - encoder.set_parts(parts); - let mut stream_buf = [0u8; 1024]; - let (cur1, cur2, _) = encoder.encode( - &mut encoder_buf[encoder_cur..], - &mut stream_buf[stream_cur..], - ); - encoder_cur += cur1; - stream_cur += cur2; +impl Default for QpackEncoder { + fn default() -> Self { + Self { + max_blocked_streams: 0, + blocked_stream_nums: 0, + table: DynamicTable::with_empty(), + tracked_stream: HashMap::new(), + field_iter: None, + field_state: None, + inst_state: None, + insert_length: 0, + required_insert_count: 0, + capacity_to_update: None, + is_huffman: true, } - let _ = encoder.decode_ins(decode("80").unwrap().as_slice()); - assert_eq!(encoder.table.known_received_count, 3); } +} - #[test] - fn encode_post_name() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, true, 1); - let _ = encoder.set_capacity(60, &mut encoder_buf[..]); - let mut encoder_cur = 0; - let mut stream_buf = [0u8; 100]; - let mut stream_cur = 0; - let mut parts = Parts::new(); - parts.update( - Field::Other(String::from("custom-key")), - String::from("custom-value1"), - ); - encoder.set_parts(parts); - let (cur1, cur2, _) = encoder.encode( - &mut encoder_buf[encoder_cur..], - &mut stream_buf[stream_cur..], - ); - encoder_cur += cur1; - stream_cur += cur2; - let mut parts = Parts::new(); - parts.update( - Field::Other(String::from("custom-key")), - String::from("custom-value2"), - ); - encoder.set_parts(parts); - let (cur1, cur2, _) = encoder.encode( - &mut encoder_buf[encoder_cur..], - &mut stream_buf[stream_cur..], - ); - encoder_cur += cur1; - stream_cur += cur2; - assert_eq!( - [16, 0, 13, 99, 117, 115, 116, 111, 109, 45, 118, 97, 108, 117, 101, 50], - stream_buf[..stream_cur] - ); - assert_eq!( - [ - 74, 99, 117, 115, 116, 111, 109, 45, 107, 101, 121, 13, 99, 117, 115, 116, 111, - 109, 45, 118, 97, 108, 117, 101, 49 - ], - encoder_buf[..encoder_cur] - ) - } - #[test] - fn test_indexing_with_litreal() { - let mut encoder_buf = [0u8; 1024]; - let mut table = DynamicTable::with_empty(); - let mut encoder = QpackEncoder::new(&mut table, 0, false, 1); - let _ = encoder.set_capacity(60, &mut encoder_buf[..]); - let encoder_cur = 0; - let mut stream_buf = [0u8; 100]; - let mut stream_cur = 0; - let mut parts = Parts::new(); - parts.update( - Field::Other(String::from("custom-key")), - String::from("custom-value1"), - ); - encoder.set_parts(parts); - let (_, cur2, _) = encoder.encode( - &mut encoder_buf[encoder_cur..], - &mut stream_buf[stream_cur..], - ); - stream_cur += cur2; - - assert_eq!( - [ - 39, 3, 99, 117, 115, 116, 111, 109, 45, 107, 101, 121, 13, 99, 117, 115, 116, 111, - 109, 45, 118, 97, 108, 117, 101, 49 - ], - stream_buf[..stream_cur] - ); - } +pub(crate) enum DecoderInst { + Ack, + StreamCancel, + InsertCountIncrement, } diff --git a/ylong_http/src/h3/qpack/error.rs b/ylong_http/src/h3/qpack/error.rs index aa3029f..a4a36a3 100644 --- a/ylong_http/src/h3/qpack/error.rs +++ b/ylong_http/src/h3/qpack/error.rs @@ -10,8 +10,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -pub enum H3errorQpack { + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum QpackError { ConnectionError(ErrorCode), + InternalError(NotClassified), } #[derive(Debug, Eq, PartialEq, Clone)] @@ -21,4 +24,12 @@ pub enum ErrorCode { EncoderStreamError = 0x0201, DecoderStreamError = 0x0202, + + H3SettingsError = 0x0109, +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum NotClassified { + DynamicTableInsufficient, + StreamBlocked, } diff --git a/ylong_http/src/h3/qpack/format/decoder.rs b/ylong_http/src/h3/qpack/format/decoder.rs index 5c40db2..e9acd1f 100644 --- a/ylong_http/src/h3/qpack/format/decoder.rs +++ b/ylong_http/src/h3/qpack/format/decoder.rs @@ -11,10 +11,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] +use std::cmp::Ordering; +use std::marker::PhantomData; use crate::h3::qpack::error::ErrorCode::DecompressionFailed; -use crate::h3::qpack::error::{ErrorCode, H3errorQpack}; +use crate::h3::qpack::error::{ErrorCode, QpackError}; use crate::h3::qpack::format::decoder::DecResult::Error; use crate::h3::qpack::integer::IntegerDecoder; use crate::h3::qpack::{ @@ -22,8 +23,6 @@ use crate::h3::qpack::{ Representation, RequireInsertCount, }; use crate::huffman::HuffmanDecoder; -use std::cmp::Ordering; -use std::marker::PhantomData; pub(crate) struct EncInstDecoder; @@ -34,17 +33,11 @@ impl EncInstDecoder { /// Decodes `buf`. Every time users call `decode`, it will try to /// decode a `EncoderInstruction`. - /// # Example -// ```no_run -// use crate::h3::qpack::format::decoder::EncInstDecoder; -// let mut decoder = EncInstDecoder::new(); -// -// ``` pub(crate) fn decode( &mut self, buf: &[u8], inst_state: &mut Option, - ) -> Result, H3errorQpack> { + ) -> Result, QpackError> { if buf.is_empty() { return Ok(None); } @@ -58,7 +51,6 @@ impl EncInstDecoder { // `Representation`, `Ok(None)` will be returned. Users need to call // `save` to save the current state to a `ReprDecStateHolder`. DecResult::NeedMore(state) => { - println!("need more"); *inst_state = Some(state); Ok(None) } @@ -83,7 +75,7 @@ impl ReprDecoder { &mut self, buf: &[u8], repr_state: &mut Option, - ) -> Result, H3errorQpack> { + ) -> Result, QpackError> { // If buf is empty, leave the state unchanged. let buf_len = buf.len(); if buf.is_empty() { @@ -98,7 +90,6 @@ impl ReprDecoder { // `Representation`, `Ok(None)` will be returned. Users need to call // `save` to save the current state to a `ReprDecStateHolder`. DecResult::NeedMore(state) => { - println!("need more"); *repr_state = Some(state); Ok(None) } @@ -109,7 +100,6 @@ impl ReprDecoder { *repr_state = Some(ReprDecodeState::ReprIndex(ReprIndex::new())); Ok(Some((buf_index, repr))) } - DecResult::Error(error) => Err(error), } } @@ -238,17 +228,17 @@ impl EncInstIndex { } fn decode(self, buf: &[u8]) -> DecResult<(usize, EncoderInstruction), InstDecodeState> { match self.inner.decode(buf) { - DecResult::Decoded((buf_index, EncoderInstPrefixBit::SETCAP, _, index)) => { + DecResult::Decoded((buf_index, EncoderInstPrefixBit::SET_CAP, _, index)) => { DecResult::Decoded((buf_index, EncoderInstruction::SetCap { capacity: index })) } DecResult::Decoded(( buf_index, - EncoderInstPrefixBit::INSERTWITHINDEX, + EncoderInstPrefixBit::INSERT_WITH_INDEX, mid_bit, index, )) => { let res = InstValueString::new( - EncoderInstPrefixBit::INSERTWITHINDEX, + EncoderInstPrefixBit::INSERT_WITH_INDEX, mid_bit, Name::Index(index), ) @@ -257,12 +247,12 @@ impl EncInstIndex { } DecResult::Decoded(( buf_index, - EncoderInstPrefixBit::INSERTWITHLITERAL, + EncoderInstPrefixBit::INSERT_WITH_LITERAL, mid_bit, namelen, )) => { let res = InstNameAndValue::new( - EncoderInstPrefixBit::INSERTWITHLITERAL, + EncoderInstPrefixBit::INSERT_WITH_LITERAL, mid_bit, namelen, ) @@ -276,9 +266,7 @@ impl EncInstIndex { DecResult::NeedMore(EncInstIndex::from_inner(inner).into()) } DecResult::Error(e) => e.into(), - _ => DecResult::Error(H3errorQpack::ConnectionError( - ErrorCode::DecompressionFailed, - )), + _ => DecResult::Error(QpackError::ConnectionError(ErrorCode::DecompressionFailed)), } } } @@ -338,9 +326,7 @@ impl ReprIndex { } DecResult::NeedMore(inner) => DecResult::NeedMore(ReprIndex::from_inner(inner).into()), DecResult::Error(e) => e.into(), - _ => DecResult::Error(H3errorQpack::ConnectionError( - ErrorCode::DecompressionFailed, - )), + _ => DecResult::Error(QpackError::ConnectionError(ErrorCode::DecompressionFailed)), } } } @@ -357,7 +343,7 @@ impl FSPTwoIntergers { return DecResult::NeedMore(self.into()); } let buf_len = buf.len(); - let mask = PrefixMask::REQUIREINSERTCOUNT; + let mask = PrefixMask::REQUIRE_INSERT_COUNT; let ric = match IntegerDecoder::first_byte(buf[buf_index - 1], mask.0) { Ok(ric) => ric, Err(mut int) => { @@ -388,7 +374,7 @@ impl FSPTwoIntergers { } let byte = buf[buf_index - 1]; let signal = (byte & 0x80) != 0; - let mask = PrefixMask::DELTABASE; + let mask = PrefixMask::DELTA_BASE; let delta_base = match IntegerDecoder::first_byte(byte, mask.0) { Ok(delta_base) => delta_base, Err(mut int) => { @@ -445,7 +431,8 @@ macro_rules! decode_first_byte { match IntegerDecoder::first_byte(buf[buf_index - 1], mask.0) { // Return the PrefixBit and index part value. Ok(idx) => DecResult::Decoded((buf_index, prefix, mid_bit, idx)), - // Index part value is longer than index(i.e. use all 1 to represent), so it needs more bytes to decode. + // Index part value is longer than index(i.e. use all 1 to represent), so it + // needs more bytes to decode. Err(int) => { let res = <$trailing_bytes>::new(prefix, mid_bit, int).decode(&buf[buf_index..]); @@ -608,7 +595,11 @@ macro_rules! name_and_value_decoder { impl $struct_name { fn new(prefix: $prefix_type, mid_bit: MidBit, namelen: usize) -> Self { - Self::from_inner(prefix, mid_bit, AsciiStringBytes::new(namelen).into()) + if mid_bit.h.is_some_and(|h| h) { + Self::from_inner(prefix, mid_bit, HuffmanStringBytes::new(namelen).into()) + } else { + Self::from_inner(prefix, mid_bit, AsciiStringBytes::new(namelen).into()) + } } fn from_inner(prefix: $prefix_type, mid_bit: MidBit, inner: $inner_type) -> Self { @@ -703,17 +694,17 @@ impl HuffmanStringBytes { Ordering::Greater | Ordering::Equal => { let pos = self.length - self.read; if self.huffman.decode(&buf[..pos]).is_err() { - return H3errorQpack::ConnectionError(DecompressionFailed).into(); + return QpackError::ConnectionError(DecompressionFailed).into(); } // let (_, mut remain_buf) = buf.split_at_mut(pos); match self.huffman.finish() { Ok(vec) => DecResult::Decoded((pos, vec)), - Err(_) => H3errorQpack::ConnectionError(DecompressionFailed).into(), + Err(_) => QpackError::ConnectionError(DecompressionFailed).into(), } } Ordering::Less => { if self.huffman.decode(buf).is_err() { - return H3errorQpack::ConnectionError(DecompressionFailed).into(); + return QpackError::ConnectionError(DecompressionFailed).into(); } self.read += buf.len(); // let (_, mut remain_buf) = buf.split_at_mut(buf.len()); @@ -757,7 +748,7 @@ impl InstValueString { fn decode(self, buf: &[u8]) -> DecResult<(usize, EncoderInstruction), InstDecodeState> { match (self.inst, self.inner.decode(buf)) { - (EncoderInstPrefixBit::INSERTWITHINDEX, DecResult::Decoded((buf_index, value))) => { + (EncoderInstPrefixBit::INSERT_WITH_INDEX, DecResult::Decoded((buf_index, value))) => { DecResult::Decoded(( buf_index, EncoderInstruction::InsertWithIndex { @@ -767,7 +758,7 @@ impl InstValueString { }, )) } - (EncoderInstPrefixBit::INSERTWITHLITERAL, DecResult::Decoded((buf_index, value))) => { + (EncoderInstPrefixBit::INSERT_WITH_LITERAL, DecResult::Decoded((buf_index, value))) => { DecResult::Decoded(( buf_index, EncoderInstruction::InsertWithLiteral { @@ -777,7 +768,7 @@ impl InstValueString { }, )) } - (_, _) => Error(H3errorQpack::ConnectionError(DecompressionFailed)), + (_, _) => Error(QpackError::ConnectionError(DecompressionFailed)), } } } @@ -835,7 +826,7 @@ impl ReprValueString { }, )) } - (_, _) => Error(H3errorQpack::ConnectionError(DecompressionFailed)), + (_, _) => Error(QpackError::ConnectionError(DecompressionFailed)), } } } @@ -851,11 +842,11 @@ pub(crate) enum DecResult { NeedMore(S), /// Errors that may occur when decoding. - Error(H3errorQpack), + Error(QpackError), } -impl From for DecResult { - fn from(e: H3errorQpack) -> Self { +impl From for DecResult { + fn from(e: QpackError) -> Self { DecResult::Error(e) } } diff --git a/ylong_http/src/h3/qpack/format/encoder.rs b/ylong_http/src/h3/qpack/format/encoder.rs index 4261422..658ec38 100644 --- a/ylong_http/src/h3/qpack/format/encoder.rs +++ b/ylong_http/src/h3/qpack/format/encoder.rs @@ -11,319 +11,261 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] +use std::arch::asm; +use std::cmp::{max, Ordering}; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::sync::Arc; +use std::{mem, result}; use crate::h3::parts::Parts; +use crate::h3::pseudo::PseudoHeaders; +use crate::h3::qpack::encoder::{EncodeMessage, UnackFields}; use crate::h3::qpack::error::ErrorCode::DecoderStreamError; -use crate::h3::qpack::error::H3errorQpack; -use crate::h3::qpack::format::decoder::DecResult; +use crate::h3::qpack::error::QpackError; +use crate::h3::qpack::format::decoder::{DecResult, LiteralString}; use crate::h3::qpack::integer::{Integer, IntegerDecoder, IntegerEncoder}; -use crate::h3::qpack::table::{DynamicTable, Field, TableIndex, TableSearcher}; +use crate::h3::qpack::table::{DynamicTable, NameField, SearchResult, TableIndex, TableSearcher}; use crate::h3::qpack::{DecoderInstPrefixBit, DecoderInstruction, EncoderInstruction, PrefixMask}; use crate::headers::HeadersIntoIter; -use crate::h3::pseudo::PseudoHeaders; -use std::arch::asm; -use std::cmp::{max, Ordering}; -use std::collections::{HashMap, VecDeque}; -use std::result; -use std::sync::Arc; +use crate::huffman::huffman_encode; pub struct ReprEncoder<'a> { + stream_id: u64, + base: u64, + is_huffman: bool, table: &'a mut DynamicTable, - draining_index: usize, - allow_post: bool, - insert_length: &'a mut usize, } impl<'a> ReprEncoder<'a> { - /// Creates a new, empty `ReprEncoder`. - /// # Examples -// ```no_run -// use ylong_http::h3::qpack::table::DynamicTable; -// use ylong_http::h3::qpack::format::encoder::ReprEncoder; -// let mut table = DynamicTable::new(4096); -// let mut insert_length = 0; -// let mut encoder = ReprEncoder::new(&mut table, 0, true, &mut insert_length); -// ``` - pub fn new( - table: &'a mut DynamicTable, - draining_index: usize, - allow_post: bool, - insert_length: &'a mut usize, - ) -> Self { + pub fn new(stream_id: u64, base: u64, is_huffman: bool, table: &'a mut DynamicTable) -> Self { Self { + stream_id, + base, + is_huffman, table, - draining_index, - allow_post, - insert_length, - } - } - - /// written to `buffer` and the length of the decoded content will be returned. - /// # Examples -// ```no_run -// use std::collections::VecDeque;use ylong_http::h3::qpack::table::DynamicTable; -// use ylong_http::h3::qpack::format::encoder::ReprEncoder; -// let mut table = DynamicTable::new(4096); -// let mut insert_length = 0; -// let mut encoder = ReprEncoder::new(&mut table, 0, true, &mut insert_length); -// let mut qpack_buffer = [0u8; 1024]; -// let mut stream_buffer = [0u8; 1024]; // stream buffer -// let mut insert_list = VecDeque::new(); // fileds to insert -// let mut required_insert_count = 0; // RFC required. -// let mut field_iter = None; // for field iterator -// let mut field_state = None; // for field encode state -// encoder.encode(&mut field_iter, &mut field_state, &mut qpack_buffer, &mut stream_buffer, &mut insert_list, &mut required_insert_count); - pub(crate) fn encode( + } + } + + fn get_prefix(&self, max_ref: usize, prefix_buf: &mut Vec) { + let required_insert_count = max_ref + 1; + + let mut wire_ric = 0; + if max_ref != 0 { + wire_ric = required_insert_count % (2 * self.table.max_entries()) + 1; + Integer::index(0x00, wire_ric, 0xff).encode(prefix_buf); + let base = self.base as usize; + if base >= required_insert_count { + Integer::index(0x00, base - required_insert_count, 0x7f).encode(prefix_buf); + } else { + Integer::index(0x80, required_insert_count - base - 1, 0x7f).encode(prefix_buf); + } + } else { + Integer::index(0x00, wire_ric, 0xff).encode(prefix_buf); + Integer::index(0x00, 0, 0x7f).encode(prefix_buf); + } + } + + pub(crate) fn iterate_encode_fields( &mut self, field_iter: &mut Option, - field_state: &mut Option, - encoder_buffer: &mut [u8], - stream_buffer: &mut [u8], - insert_list: &mut VecDeque<(Field, String)>, - required_insert_count: &mut usize, - ) -> (usize, usize) { - let mut cur_encoder = 0; - let mut cur_stream = 0; - let mut base = self.table.insert_count; + track_map: &mut HashMap, + blocked_cnt: &mut usize, + allow_block: bool, + fields: &mut Vec, + inst: &mut Vec, + ) { + let mut max_dynamic = 0; + + let mut ref_fields = HashSet::::new(); + if let Some(mut iter) = field_iter.take() { - while let Some((h, v)) = iter.next() { - let searcher = TableSearcher::new(self.table); - let mut stream_result: Result = Result::Ok(0); - let mut encoder_result: Result = Result::Ok(0); - let static_index = searcher.find_index_static(&h, &v); - if static_index != Some(TableIndex::None) { - if let Some(TableIndex::Field(index)) = static_index { - // Encode as index in static table - stream_result = - Indexed::new(index, true).encode(&mut stream_buffer[cur_stream..]); - } - } else { - let mut dynamic_index = searcher.find_index_dynamic(&h, &v); - let static_name_index = searcher.find_index_name_static(&h, &v); - let mut dynamic_name_index = Some(TableIndex::None); - if dynamic_index == Some(TableIndex::None) || !self.should_index(&dynamic_index) - { - // if index is close to eviction, drop it and use duplicate - // let dyn_index = dynamic_index.clone(); - // dynamic_index = Some(TableIndex::None); - let mut is_duplicate = false; - if static_name_index == Some(TableIndex::None) { - dynamic_name_index = searcher.find_index_name_dynamic(&h, &v); - } - - if self.table.have_enough_space(&h, &v, self.insert_length) { - if !self.should_index(&dynamic_index) { - if let Some(TableIndex::Field(index)) = dynamic_index { - encoder_result = Duplicate::new(base - index - 1) - .encode(&mut encoder_buffer[cur_encoder..]); - self.table.update(h.clone(), v.clone()); - base = max(base, self.table.insert_count); - dynamic_index = - Some(TableIndex::Field(self.table.insert_count - 1)); - is_duplicate = true; - } - } else { - encoder_result = match ( - &static_name_index, - &dynamic_name_index, - self.should_index(&dynamic_name_index), - ) { - // insert with name reference in static table - (Some(TableIndex::FieldName(index)), _, _) => { - InsertWithName::new( - *index, - v.clone().into_bytes(), - false, - true, - ) - .encode(&mut encoder_buffer[cur_encoder..]) - } - // insert with name reference in dynamic table - (_, Some(TableIndex::FieldName(index)), true) => { - // convert abs index to rel index - InsertWithName::new( - base - index - 1, - v.clone().into_bytes(), - false, - false, - ) - .encode(&mut encoder_buffer[cur_encoder..]) - } - // Duplicate - (_, Some(TableIndex::FieldName(index)), false) => { - let res = Duplicate::new(*index) - .encode(&mut encoder_buffer[cur_encoder..]); - self.table.update(h.clone(), v.clone()); - base = max(base, self.table.insert_count); - dynamic_name_index = Some(TableIndex::FieldName( - self.table.insert_count - 1, - )); - is_duplicate = true; - res - } - // insert with literal name - (_, _, _) => InsertWithLiteral::new( - h.clone().into_string().into_bytes(), - v.clone().into_bytes(), - false, - ) - .encode(&mut encoder_buffer[cur_encoder..]), - } - }; - if self.table.size() + h.len() + v.len() + 32 >= self.table.capacity() { - self.draining_index += 1; - } - insert_list.push_back((h.clone(), v.clone())); - *self.insert_length += h.len() + v.len() + 32; - } - if self.allow_post && !is_duplicate { - for (post_index, (t_h, t_v)) in insert_list.iter().enumerate() { - if t_h == &h && t_v == &v { - dynamic_index = Some(TableIndex::Field(post_index)) - } - if t_h == &h { - dynamic_name_index = Some(TableIndex::FieldName(post_index)); - } - } - } - } + while let Some((field_h, field_v)) = iter.next() { + if let Some(dyn_ref) = + self.encode_field(&mut ref_fields, field_h, field_v, inst, fields, allow_block) + { + max_dynamic = max(max_dynamic, dyn_ref); + } + } + } + + let enc_field_lines = mem::take(fields); + + self.get_prefix(max_dynamic, fields); + fields.extend_from_slice(enc_field_lines.as_slice()); + if (max_dynamic as u64) > self.table.known_recved_count() { + *blocked_cnt += 1; + } + + if !ref_fields.is_empty() { + track_map + .entry(self.stream_id) + .or_default() + .update(ref_fields); + } + } - if dynamic_index == Some(TableIndex::None) { - if dynamic_name_index != Some(TableIndex::None) { - //Encode with name reference in dynamic table - if let Some(TableIndex::FieldName(index)) = dynamic_name_index { - // use post-base index - if base <= index { - stream_result = IndexingWithPostName::new( - index - base, - v.clone().into_bytes(), - false, - false, - ) - .encode(&mut stream_buffer[cur_stream..]); - } else { - stream_result = IndexingWithName::new( - base - index - 1, - v.clone().into_bytes(), - false, - false, - false, - ) - .encode(&mut stream_buffer[cur_stream..]); - } - *required_insert_count = max(*required_insert_count, index + 1); - } - } else { - // Encode with name reference in static table - // or Encode as Literal - if static_name_index != Some(TableIndex::None) { - if let Some(TableIndex::FieldName(index)) = static_name_index { - stream_result = IndexingWithName::new( - index, - v.into_bytes(), - false, - true, - false, - ) - .encode(&mut stream_buffer[cur_stream..]); - } - } else { - stream_result = IndexingWithLiteral::new( - h.into_string().into_bytes(), - v.into_bytes(), - false, - false, - ) - .encode(&mut stream_buffer[cur_stream..]); - } - } + fn encode_field( + &mut self, + ref_fields: &mut HashSet, + field_h: NameField, + field_v: String, + inst: &mut Vec, + fields: &mut Vec, + allow_block: bool, + ) -> Option { + let mut dynamic_index = None; + match self.search_filed_from_table(&field_h, field_v.as_str(), allow_block) { + SearchResult::StaticIndex(index) => { + Indexed::new(index, true).encode(fields); + } + SearchResult::StaticNameIndex(index) => { + IndexingWithName::new(index, field_v.into_bytes(), self.is_huffman, true, false) + .encode(fields); + } + SearchResult::DynamicIndex(index) => { + self.indexed_dynamic_field(index, fields); + if ref_fields.insert(index) { + self.table.track_field(index); + } + dynamic_index = Some(index); + } + SearchResult::DynamicNameIndex(index) => { + self.indexed_name_ref_with_literal_field(index, field_v, self.is_huffman, fields); + if ref_fields.insert(index) { + self.table.track_field(index) + } + dynamic_index = Some(index); + } + SearchResult::NotFound => { + // allow_block允许插入或引用未确认的index + if allow_block { + if let Some(index) = self.add_field_to_dynamic( + field_h.clone(), + field_v.clone(), + inst, + self.is_huffman, + ) { + self.indexed_dynamic_field(index, fields); + ref_fields.insert(index); + self.table.track_field(index); + dynamic_index = Some(index); } else { - assert!(dynamic_index != Some(TableIndex::None)); - // Encode with index in dynamic table - if let Some(TableIndex::Field(index)) = dynamic_index { - // use post-base index - if base <= index { - stream_result = IndexedWithPostName::new(index - base) - .encode(&mut stream_buffer[cur_stream..]); - } else { - stream_result = Indexed::new(base - index - 1, false) - .encode(&mut stream_buffer[cur_stream..]); - } - *required_insert_count = max(*required_insert_count, index + 1); - } + IndexingWithLiteral::new( + field_h.to_string().into_bytes(), + field_v.into_bytes(), + self.is_huffman, + false, + ) + .encode(fields); } + } else { + IndexingWithLiteral::new( + field_h.to_string().into_bytes(), + field_v.into_bytes(), + self.is_huffman, + false, + ) + .encode(fields); } + } + } + dynamic_index + } - match (encoder_result, stream_result) { - (Ok(encoder_size), Ok(stream_size)) => { - cur_stream += stream_size; - cur_encoder += encoder_size; - } - (Err(state), Ok(_)) => { - *field_iter = Some(iter); - *field_state = Some(state); - return (encoder_buffer.len(), stream_buffer.len()); - } - (Ok(_), Err(state)) => { - *field_iter = Some(iter); - *field_state = Some(state); - return (encoder_buffer.len(), stream_buffer.len()); - } - (Err(_), Err(state)) => { - *field_iter = Some(iter); - *field_state = Some(state); - return (encoder_buffer.len(), stream_buffer.len()); - } - } + fn indexed_dynamic_field(&self, index: usize, fields: &mut Vec) { + if index as u64 >= self.base { + IndexedWithPost::new(index - (self.base as usize)).encode(fields); + } else { + Indexed::new((self.base as usize) - index - 1, false).encode(fields); + } + } + + fn indexed_name_ref_with_literal_field( + &self, + index: usize, + field_v: String, + is_huffman: bool, + fields: &mut Vec, + ) { + if index as u64 >= self.base { + IndexingWithPostName::new( + index - (self.base as usize), + field_v.into_bytes(), + is_huffman, + false, + ) + .encode(fields); + } else { + InsertWithName::new( + (self.base as usize) - index - 1, + field_v.into_bytes(), + is_huffman, + false, + ) + .encode(fields); + } + } + + fn add_field_to_dynamic( + &mut self, + field_h: NameField, + field_v: String, + inst: &mut Vec, + is_huffman: bool, + ) -> Option { + let header_size = field_h.len() + field_v.len() + 32; + if let Some(len) = self.table.can_evict(header_size) { + InsertWithLiteral::new( + field_h.to_string().into_bytes(), + field_v.clone().into_bytes(), + is_huffman, + ) + .encode(inst); + self.table.evict_drained(len); + Some(self.table.update(field_h, field_v)) + } else { + IndexingWithLiteral::new( + field_h.to_string().into_bytes(), + field_v.into_bytes(), + is_huffman, + false, + ) + .encode(inst); + None + } + } + + fn search_filed_from_table( + &mut self, + h: &NameField, + v: &str, + allow_block: bool, + ) -> SearchResult { + let searcher = TableSearcher::new(self.table); + let mut search_result = SearchResult::NotFound; + match searcher.search_in_static(h, v) { + TableIndex::Field(index) => { + return SearchResult::StaticIndex(index); } + TableIndex::FieldName(index) => { + search_result = SearchResult::StaticNameIndex(index); + } + TableIndex::None => {} } - (cur_encoder, cur_stream) - } - // ## 2.1.1.1. Avoiding Prohibited Insertions - // To ensure that the encoder is not prevented from adding new entries, the encoder can - // avoid referencing entries that are close to eviction. Rather than reference such an - // entry, the encoder can emit a Duplicate instruction (Section 4.3.4) and reference - // the duplicate instead. - // - // Determining which entries are too close to eviction to reference is an encoder preference. - // One heuristic is to target a fixed amount of available space in the dynamic table: - // either unused space or space that can be reclaimed by evicting non-blocking entries. - // To achieve this, the encoder can maintain a draining index, which is the smallest - // absolute index (Section 3.2.4) in the dynamic table that it will emit a reference for. - // As new entries are inserted, the encoder increases the draining index to maintain the - // section of the table that it will not reference. If the encoder does not create new - // references to entries with an absolute index lower than the draining index, the number - // of unacknowledged references to those entries will eventually become zero, allowing - // them to be evicted. - // - // <-- Newer Entries Older Entries --> - // (Larger Indices) (Smaller Indices) - // +--------+---------------------------------+----------+ - // | Unused | Referenceable | Draining | - // | Space | Entries | Entries | - // +--------+---------------------------------+----------+ - // ^ ^ ^ - // | | | - // Insertion Point Draining Index Dropping - // Point - pub(crate) fn should_index(&self, index: &Option) -> bool { - match index { - Some(TableIndex::Field(x)) => { - if *x < self.draining_index { - return false; - } - true + match searcher.search_in_dynamic(h, v, allow_block) { + TableIndex::Field(index) => { + return SearchResult::DynamicIndex(index); } - Some(TableIndex::FieldName(x)) => { - if *x < self.draining_index { - return false; + TableIndex::FieldName(index) => { + if search_result == SearchResult::NotFound { + search_result = SearchResult::DynamicNameIndex(index); } - true } - _ => true, + TableIndex::None => {} } + + search_result } } @@ -335,7 +277,7 @@ pub(crate) enum ReprEncodeState { IndexingWithName(IndexingWithName), IndexingWithPostName(IndexingWithPostName), IndexingWithLiteral(IndexingWithLiteral), - IndexedWithPostName(IndexedWithPostName), + IndexedWithPostName(IndexedWithPost), Duplicate(Duplicate), } @@ -350,14 +292,12 @@ impl SetCap { pub(crate) fn new(capacity: usize) -> Self { Self { - capacity: Integer::index(0x20, capacity, PrefixMask::SETCAP.0), + capacity: Integer::index(0x20, capacity, PrefixMask::SET_CAP.0), } } - pub(crate) fn encode(self, dst: &mut [u8]) -> Result { - self.capacity - .encode(dst) - .map_err(|e| ReprEncodeState::SetCap(SetCap::from(e))) + pub(crate) fn encode(self, dst: &mut Vec) { + self.capacity.encode(dst) } } @@ -365,6 +305,7 @@ pub(crate) struct Duplicate { index: Integer, } +#[allow(unused)] impl Duplicate { fn from(index: Integer) -> Self { Self { index } @@ -376,10 +317,8 @@ impl Duplicate { } } - fn encode(self, dst: &mut [u8]) -> Result { - self.index - .encode(dst) - .map_err(|e| ReprEncodeState::Duplicate(Duplicate::from(e))) + fn encode(self, dst: &mut Vec) { + self.index.encode(dst) } } @@ -406,32 +345,28 @@ impl Indexed { } } - fn encode(self, dst: &mut [u8]) -> Result { - self.index - .encode(dst) - .map_err(|e| ReprEncodeState::Indexed(Indexed::from(e))) + fn encode(self, dst: &mut Vec) { + self.index.encode(dst) } } -pub(crate) struct IndexedWithPostName { +pub(crate) struct IndexedWithPost { index: Integer, } -impl IndexedWithPostName { +impl IndexedWithPost { fn from(index: Integer) -> Self { Self { index } } fn new(index: usize) -> Self { Self { - index: Integer::index(0x10, index, PrefixMask::INDEXINGWITHPOSTNAME.0), + index: Integer::index(0x10, index, PrefixMask::INDEXED_WITH_POST_NAME.0), } } - fn encode(self, dst: &mut [u8]) -> Result { - self.index - .encode(dst) - .map_err(|e| ReprEncodeState::IndexedWithPostName(IndexedWithPostName::from(e))) + fn encode(self, dst: &mut Vec) { + self.index.encode(dst) } } @@ -448,22 +383,20 @@ impl InsertWithName { if is_static { Self { inner: IndexAndValue::new() - .set_index(0xc0, index, PrefixMask::INSERTWITHINDEX.0) + .set_index(0xc0, index, PrefixMask::INSERT_WITH_INDEX.0) .set_value(value, is_huffman), } } else { Self { inner: IndexAndValue::new() - .set_index(0x80, index, PrefixMask::INSERTWITHINDEX.0) + .set_index(0x80, index, PrefixMask::INSERT_WITH_INDEX.0) .set_value(value, is_huffman), } } } - fn encode(self, dst: &mut [u8]) -> Result { - self.inner - .encode(dst) - .map_err(|e| ReprEncodeState::InsertWithName(InsertWithName::from(e))) + fn encode(self, dst: &mut Vec) { + self.inner.encode(dst) } } @@ -486,31 +419,29 @@ impl IndexingWithName { match (no_permit, is_static) { (true, true) => Self { inner: IndexAndValue::new() - .set_index(0x70, index, PrefixMask::INDEXINGWITHNAME.0) + .set_index(0x70, index, PrefixMask::INDEXING_WITH_NAME.0) .set_value(value, is_huffman), }, (true, false) => Self { inner: IndexAndValue::new() - .set_index(0x60, index, PrefixMask::INDEXINGWITHNAME.0) + .set_index(0x60, index, PrefixMask::INDEXING_WITH_NAME.0) .set_value(value, is_huffman), }, (false, true) => Self { inner: IndexAndValue::new() - .set_index(0x50, index, PrefixMask::INDEXINGWITHNAME.0) + .set_index(0x50, index, PrefixMask::INDEXING_WITH_NAME.0) .set_value(value, is_huffman), }, (false, false) => Self { inner: IndexAndValue::new() - .set_index(0x40, index, PrefixMask::INDEXINGWITHNAME.0) + .set_index(0x40, index, PrefixMask::INDEXING_WITH_NAME.0) .set_value(value, is_huffman), }, } } - fn encode(self, dst: &mut [u8]) -> Result { - self.inner - .encode(dst) - .map_err(|e| ReprEncodeState::IndexingWithName(IndexingWithName::from(e))) + fn encode(self, dst: &mut Vec) { + self.inner.encode(dst) } } @@ -527,22 +458,20 @@ impl IndexingWithPostName { if no_permit { Self { inner: IndexAndValue::new() - .set_index(0x08, index, PrefixMask::INDEXINGWITHPOSTNAME.0) + .set_index(0x08, index, PrefixMask::INDEXED_WITH_POST_NAME.0) .set_value(value, is_huffman), } } else { Self { inner: IndexAndValue::new() - .set_index(0x00, index, PrefixMask::INDEXINGWITHPOSTNAME.0) + .set_index(0x00, index, PrefixMask::INDEXED_WITH_POST_NAME.0) .set_value(value, is_huffman), } } } - fn encode(self, dst: &mut [u8]) -> Result { - self.inner - .encode(dst) - .map_err(|e| ReprEncodeState::IndexingWithPostName(IndexingWithPostName::from(e))) + fn encode(self, dst: &mut Vec) { + self.inner.encode(dst) } } @@ -555,22 +484,22 @@ impl IndexingWithLiteral { match (no_permit, is_huffman) { (true, true) => Self { inner: NameAndValue::new() - .set_index(0x38, name.len(), PrefixMask::INDEXINGWITHLITERAL.0) + .set_index(0x38, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), }, (true, false) => Self { inner: NameAndValue::new() - .set_index(0x30, name.len(), PrefixMask::INDEXINGWITHLITERAL.0) + .set_index(0x30, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), }, (false, true) => Self { inner: NameAndValue::new() - .set_index(0x28, name.len(), PrefixMask::INDEXINGWITHLITERAL.0) + .set_index(0x28, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), }, (false, false) => Self { inner: NameAndValue::new() - .set_index(0x20, name.len(), PrefixMask::INDEXINGWITHLITERAL.0) + .set_index(0x20, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), }, } @@ -580,10 +509,8 @@ impl IndexingWithLiteral { Self { inner } } - fn encode(self, dst: &mut [u8]) -> Result { - self.inner - .encode(dst) - .map_err(|e| ReprEncodeState::InsertWithLiteral(InsertWithLiteral::from(e))) + fn encode(self, dst: &mut Vec) { + self.inner.encode(dst) } } @@ -596,13 +523,13 @@ impl InsertWithLiteral { if is_huffman { Self { inner: NameAndValue::new() - .set_index(0x60, name.len(), PrefixMask::INSERTWITHLITERAL.0) + .set_index(0x60, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), } } else { Self { inner: NameAndValue::new() - .set_index(0x40, name.len(), PrefixMask::INSERTWITHLITERAL.0) + .set_index(0x40, name.len(), PrefixMask::INSERT_WITH_LITERAL.0) .set_name_and_value(name, value, is_huffman), } } @@ -612,10 +539,8 @@ impl InsertWithLiteral { Self { inner } } - fn encode(self, dst: &mut [u8]) -> Result { - self.inner - .encode(dst) - .map_err(|e| ReprEncodeState::InsertWithLiteral(InsertWithLiteral::from(e))) + fn encode(self, dst: &mut Vec) { + self.inner.encode(dst) } } @@ -625,15 +550,9 @@ pub(crate) struct IndexAndValue { value_octets: Option, } macro_rules! check_and_encode { - ($item: expr, $dst: expr, $cur: expr, $self: expr) => {{ + ($item: expr, $dst: expr) => {{ if let Some(i) = $item.take() { - match i.encode($dst) { - Ok(len) => $cur += len, - Err(e) => { - $item = Some(e); - return Err($self); - } - }; + i.encode($dst) } }}; } @@ -652,17 +571,16 @@ impl IndexAndValue { } fn set_value(mut self, value: Vec, is_huffman: bool) -> Self { - self.value_length = Some(Integer::length(value.len(), is_huffman)); - self.value_octets = Some(Octets::new(value)); + let huffman_value = Octets::new(value, is_huffman); + self.value_length = Some(Integer::length(huffman_value.len(), is_huffman)); + self.value_octets = Some(huffman_value); self } - fn encode(mut self, dst: &mut [u8]) -> Result { - let mut cur = 0; - check_and_encode!(self.index, &mut dst[cur..], cur, self); - check_and_encode!(self.value_length, &mut dst[cur..], cur, self); - check_and_encode!(self.value_octets, &mut dst[cur..], cur, self); - Ok(cur) + fn encode(mut self, dst: &mut Vec) { + check_and_encode!(self.index, dst); + check_and_encode!(self.value_length, dst); + check_and_encode!(self.value_octets, dst); } } @@ -691,21 +609,22 @@ impl NameAndValue { } fn set_name_and_value(mut self, name: Vec, value: Vec, is_huffman: bool) -> Self { - self.name_length = Some(Integer::length(name.len(), is_huffman)); - self.name_octets = Some(Octets::new(name)); - self.value_length = Some(Integer::length(value.len(), is_huffman)); - self.value_octets = Some(Octets::new(value)); + let huffman_name = Octets::new(name, is_huffman); + self.name_length = Some(Integer::length(huffman_name.len(), is_huffman)); + self.name_octets = Some(huffman_name); + let huffman_value = Octets::new(value, is_huffman); + self.value_length = Some(Integer::length(huffman_value.len(), is_huffman)); + self.value_octets = Some(huffman_value); self } - fn encode(mut self, dst: &mut [u8]) -> Result { - let mut cur = 0; - check_and_encode!(self.index, &mut dst[cur..], cur, self); - // check_and_encode!(self.name_length, &mut dst[cur..], cur, self); //no need for qpack cause it in index. - check_and_encode!(self.name_octets, &mut dst[cur..], cur, self); - check_and_encode!(self.value_length, &mut dst[cur..], cur, self); - check_and_encode!(self.value_octets, &mut dst[cur..], cur, self); - Ok(cur) + fn encode(mut self, dst: &mut Vec) { + check_and_encode!(self.index, dst); + // check_and_encode!(self.name_length, &mut dst[cur..], cur, self); //no need + // for qpack cause it in index. + check_and_encode!(self.name_octets, dst); + check_and_encode!(self.value_length, dst); + check_and_encode!(self.value_octets, dst); } } @@ -750,7 +669,7 @@ impl<'a> DecInstDecoder<'a> { pub(crate) fn decode( &mut self, ins_state: &mut Option, - ) -> Result, H3errorQpack> { + ) -> Result, QpackError> { if self.buf.is_empty() { return Ok(None); } @@ -773,6 +692,7 @@ impl<'a> DecInstDecoder<'a> { } } } + state_def!( DecInstIndexInner, (DecoderInstPrefixBit, usize), @@ -796,14 +716,14 @@ impl DecInstIndex { DecResult::Decoded((DecoderInstPrefixBit::ACK, index)) => { DecResult::Decoded(DecoderInstruction::Ack { stream_id: index }) } - DecResult::Decoded((DecoderInstPrefixBit::STREAMCANCEL, index)) => { + DecResult::Decoded((DecoderInstPrefixBit::STREAM_CANCEL, index)) => { DecResult::Decoded(DecoderInstruction::StreamCancel { stream_id: index }) } - DecResult::Decoded((DecoderInstPrefixBit::INSERTCOUNTINCREMENT, index)) => { + DecResult::Decoded((DecoderInstPrefixBit::INSERT_COUNT_INCREMENT, index)) => { DecResult::Decoded(DecoderInstruction::InsertCountIncrement { increment: index }) } DecResult::Error(e) => e.into(), - _ => DecResult::Error(H3errorQpack::ConnectionError(DecoderStreamError)), + _ => DecResult::Error(QpackError::ConnectionError(DecoderStreamError)), } } } @@ -828,7 +748,8 @@ impl InstFirstByte { match IntegerDecoder::first_byte(byte, mask.0) { // Return the ReprPrefixBit and index part value. Ok(idx) => DecResult::Decoded((inst, idx)), - // Index part value is longer than index(i.e. use all 1 to represent), so it needs more bytes to decode. + // Index part value is longer than index(i.e. use all 1 to represent), so it needs more + // bytes to decode. Err(int) => InstTrailingBytes::new(inst, int).decode(buf), } } @@ -867,36 +788,25 @@ impl InstTrailingBytes { pub(crate) struct Octets { src: Vec, - idx: usize, } impl Octets { - fn new(src: Vec) -> Self { - Self { src, idx: 0 } + fn new(src: Vec, is_huffman: bool) -> Self { + if is_huffman { + let mut dst = Vec::with_capacity(src.len()); + huffman_encode(src.as_slice(), dst.as_mut()); + Self { src: dst } + } else { + Self { src } + } } - fn encode(mut self, dst: &mut [u8]) -> Result { - let mut cur = 0; - - let input_len = self.src.len() - self.idx; - let output_len = dst.len(); - - if input_len == 0 { - return Ok(cur); - } + fn encode(self, dst: &mut Vec) { + dst.extend_from_slice(self.src.as_slice()); + } - match output_len.cmp(&input_len) { - Ordering::Greater | Ordering::Equal => { - dst[..input_len].copy_from_slice(&self.src[self.idx..]); - cur += input_len; - Ok(cur) - } - Ordering::Less => { - dst[..].copy_from_slice(&self.src[self.idx..self.idx + output_len]); - self.idx += output_len; - Err(self) - } - } + fn len(&self) -> usize { + self.src.len() } } @@ -928,34 +838,34 @@ impl PartsIter { /// Gets headers in the order of `Method`, `Status`, `Scheme`, `Path`, /// `Authority` and `Other`. - fn next(&mut self) -> Option<(Field, String)> { + fn next(&mut self) -> Option<(NameField, String)> { loop { match self.next_type { PartsIterDirection::Method => match self.pseudo.take_method() { - Some(value) => return Some((Field::Method, value)), + Some(value) => return Some((NameField::Method, value)), None => self.next_type = PartsIterDirection::Status, }, PartsIterDirection::Status => match self.pseudo.take_status() { - Some(value) => return Some((Field::Status, value)), + Some(value) => return Some((NameField::Status, value)), None => self.next_type = PartsIterDirection::Scheme, }, PartsIterDirection::Scheme => match self.pseudo.take_scheme() { - Some(value) => return Some((Field::Scheme, value)), + Some(value) => return Some((NameField::Scheme, value)), None => self.next_type = PartsIterDirection::Path, }, PartsIterDirection::Path => match self.pseudo.take_path() { - Some(value) => return Some((Field::Path, value)), + Some(value) => return Some((NameField::Path, value)), None => self.next_type = PartsIterDirection::Authority, }, PartsIterDirection::Authority => match self.pseudo.take_authority() { - Some(value) => return Some((Field::Authority, value)), + Some(value) => return Some((NameField::Authority, value)), None => self.next_type = PartsIterDirection::Other, }, PartsIterDirection::Other => { return self .map .next() - .map(|(h, v)| (Field::Other(h.to_string()), v.to_string().unwrap())); + .map(|(h, v)| (NameField::Other(h.to_string()), v.to_string().unwrap())); } } } diff --git a/ylong_http/src/h3/qpack/integer.rs b/ylong_http/src/h3/qpack/integer.rs index e234ff1..ba5ce7f 100644 --- a/ylong_http/src/h3/qpack/integer.rs +++ b/ylong_http/src/h3/qpack/integer.rs @@ -11,11 +11,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] - -use crate::h3::qpack::error::{ErrorCode, H3errorQpack}; use std::cmp::Ordering; +use crate::h3::qpack::error::{ErrorCode, QpackError}; + pub(crate) struct Integer { pub(crate) int: IntegerEncoder, } @@ -28,22 +27,18 @@ impl Integer { } pub(crate) fn length(length: usize, is_huffman: bool) -> Self { Self { - int: IntegerEncoder::new(u8::from(is_huffman), length, 0x7f), + int: IntegerEncoder::new(pre_mask(is_huffman), length, 0x7f), } } - pub(crate) fn encode(mut self, dst: &mut [u8]) -> Result { - let mut cur = 0; + pub(crate) fn encode(mut self, dst: &mut Vec) { while !self.int.is_finish() { - let dst = &mut dst[cur..]; - if dst.is_empty() { - return Err(self); + if let Some(byte) = self.int.next_byte() { + dst.push(byte) } - dst[0] = self.int.next_byte().unwrap(); - cur += 1; } - Ok(cur) } } + pub(crate) struct IntegerDecoder { index: usize, shift: u32, @@ -67,14 +62,12 @@ impl IntegerDecoder { /// Continues computing the integer based on the next byte of the input. /// Returns `Ok(Some(index))` if the result is obtained, otherwise returns /// `Ok(None)`, and returns Err in case of overflow. - pub(crate) fn next_byte(&mut self, byte: u8) -> Result, H3errorQpack> { + pub(crate) fn next_byte(&mut self, byte: u8) -> Result, QpackError> { self.index = 1usize .checked_shl(self.shift - 1) .and_then(|res| res.checked_mul((byte & 0x7f) as usize)) .and_then(|res| res.checked_add(self.index)) - .ok_or(H3errorQpack::ConnectionError( - ErrorCode::DecompressionFailed, - ))?; //todo: modify the error code + .ok_or(QpackError::ConnectionError(ErrorCode::DecompressionFailed))?; // todo: modify the error code self.shift += 7; match (byte & 0x80) == 0x00 { true => Ok(Some(self.index)), @@ -146,3 +139,11 @@ impl IntegerEncoder { matches!(self.state, IntegerEncodeState::Finish) } } + +fn pre_mask(is_huffman: bool) -> u8 { + if is_huffman { + 0x80 + } else { + 0 + } +} diff --git a/ylong_http/src/h3/qpack/mod.rs b/ylong_http/src/h3/qpack/mod.rs index df03515..9f34e45 100644 --- a/ylong_http/src/h3/qpack/mod.rs +++ b/ylong_http/src/h3/qpack/mod.rs @@ -11,19 +11,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] - pub mod decoder; pub mod encoder; pub(crate) mod error; pub mod format; mod integer; pub mod table; +pub(crate) use decoder::{FieldDecodeState, FiledLines, QpackDecoder}; +pub(crate) use encoder::{DecoderInst, QpackEncoder}; + use crate::h3::qpack::format::decoder::Name; -pub(crate) use decoder::FiledLines; -pub(crate) use decoder::QpackDecoder; -pub(crate) use encoder::DecoderInst; -pub(crate) use encoder::QpackEncoder; pub(crate) struct RequireInsertCount(usize); @@ -40,15 +37,15 @@ pub(crate) struct ReprPrefixBit(u8); /// # Prefix bit: /// ## Encoder Instructions: -/// SETCAP: 0x20 -/// INSERTWITHINDEX: 0x80 -/// INSERTWITHLITERAL: 0x40 +/// SET_CAP: 0x20 +/// INSERT_WITH_INDEX: 0x80 +/// INSERT_WITH_LITERAL: 0x40 /// DUPLICATE: 0x00 /// /// ## Decoder Instructions: /// ACK: 0x80 -/// STREAMCANCEL: 0x40 -/// INSERTCOUNTINCREMENT: 0x00 +/// STREAM_CANCEL: 0x40 +/// INSERT_COUNT_INCREMENT: 0x00 /// /// ## Representation: /// INDEXED: 0x80 @@ -59,22 +56,24 @@ pub(crate) struct ReprPrefixBit(u8); impl DecoderInstPrefixBit { pub(crate) const ACK: Self = Self(0x80); - pub(crate) const STREAMCANCEL: Self = Self(0x40); - pub(crate) const INSERTCOUNTINCREMENT: Self = Self(0x00); + pub(crate) const STREAM_CANCEL: Self = Self(0x40); + pub(crate) const INSERT_COUNT_INCREMENT: Self = Self(0x00); pub(crate) fn from_u8(byte: u8) -> Self { match byte { - x if x >= 0x80 => Self::ACK, - x if x >= 0x40 => Self::STREAMCANCEL, - _ => Self::INSERTCOUNTINCREMENT, + x if x & 0x80 == 0x80 => Self::ACK, + x if x & 0xC0 == 0x40 => Self::STREAM_CANCEL, + x if x & 0xC0 == 0x0 => Self::INSERT_COUNT_INCREMENT, + _ => unreachable!(), } } pub(crate) fn prefix_index_mask(&self) -> PrefixMask { match self.0 { 0x80 => PrefixMask::ACK, - 0x40 => PrefixMask::STREAMCANCEL, - _ => PrefixMask::INSERTCOUNTINCREMENT, + 0x40 => PrefixMask::STREAM_CANCEL, + 0x0 => PrefixMask::INSERT_COUNT_INCREMENT, + _ => unreachable!(), } } @@ -88,25 +87,25 @@ impl DecoderInstPrefixBit { } impl EncoderInstPrefixBit { - pub(crate) const SETCAP: Self = Self(0x20); - pub(crate) const INSERTWITHINDEX: Self = Self(0x80); - pub(crate) const INSERTWITHLITERAL: Self = Self(0x40); + pub(crate) const SET_CAP: Self = Self(0x20); + pub(crate) const INSERT_WITH_INDEX: Self = Self(0x80); + pub(crate) const INSERT_WITH_LITERAL: Self = Self(0x40); pub(crate) const DUPLICATE: Self = Self(0x00); pub(crate) fn from_u8(byte: u8) -> Self { match byte { - x if x >= 0x80 => Self::INSERTWITHINDEX, - x if x >= 0x40 => Self::INSERTWITHLITERAL, - x if x >= 0x20 => Self::SETCAP, + x if x >= 0x80 => Self::INSERT_WITH_INDEX, + x if x >= 0x40 => Self::INSERT_WITH_LITERAL, + x if x >= 0x20 => Self::SET_CAP, _ => Self::DUPLICATE, } } pub(crate) fn prefix_index_mask(&self) -> PrefixMask { match self.0 { - 0x80 => PrefixMask::INSERTWITHINDEX, - 0x40 => PrefixMask::INSERTWITHLITERAL, - 0x20 => PrefixMask::SETCAP, + 0x80 => PrefixMask::INSERT_WITH_INDEX, + 0x40 => PrefixMask::INSERT_WITH_LITERAL, + 0x20 => PrefixMask::SET_CAP, _ => PrefixMask::DUPLICATE, } } @@ -138,6 +137,7 @@ impl EncoderInstPrefixBit { } impl ReprPrefixBit { + // 此处的值为前缀1的位置,并没有实际意义 pub(crate) const INDEXED: Self = Self(0x80); pub(crate) const INDEXEDWITHPOSTINDEX: Self = Self(0x10); pub(crate) const LITERALWITHINDEXING: Self = Self(0x40); @@ -161,15 +161,16 @@ impl ReprPrefixBit { pub(crate) fn prefix_index_mask(&self) -> PrefixMask { match self.0 { 0x80 => PrefixMask::INDEXED, - 0x40 => PrefixMask::INDEXINGWITHNAME, - 0x20 => PrefixMask::INDEXINGWITHLITERAL, - 0x10 => PrefixMask::INDEXEDWITHPOSTNAME, - _ => PrefixMask::INDEXINGWITHPOSTNAME, + 0x40 => PrefixMask::INDEXING_WITH_NAME, + 0x20 => PrefixMask::INDEXING_WITH_LITERAL, + 0x10 => PrefixMask::INDEXED_WITH_POST_NAME, + _ => PrefixMask::INDEXING_WITH_LITERAL, } } - /// Unlike Hpack, QPACK has some special value for the first byte of an integer. - /// Like T indicating whether the reference is into the static or dynamic table. + /// Unlike Hpack, QPACK has some special value for the first byte of an + /// integer. Like T indicating whether the reference is into the static + /// or dynamic table. pub(crate) fn prefix_midbit_value(&self, byte: u8) -> MidBit { match self.0 { 0x80 => MidBit { @@ -227,17 +228,18 @@ pub(crate) enum DecoderInstruction { } pub(crate) enum Representation { - /// An indexed field line format identifies an entry in the static table or an entry in - /// the dynamic table with an absolute index less than the value of the Base. - /// 0 1 2 3 4 5 6 7 + /// An indexed field line format identifies an entry in the static table or + /// an entry in the dynamic table with an absolute index less than the + /// value of the Base. 0 1 2 3 4 5 6 7 /// +---+---+---+---+---+---+---+---+ /// | 1 | T | Index (6+) | /// +---+---+-----------------------+ - /// This format starts with the '1' 1-bit pattern, followed by the 'T' bit, indicating - /// whether the reference is into the static or dynamic table. The 6-bit prefix integer - /// (Section 4.1.1) that follows is used to locate the table entry for the field line. When T=1, - /// the number represents the static table index; when T=0, the number is the relative index of - /// the entry in the dynamic table. + /// This format starts with the '1' 1-bit pattern, followed by the 'T' bit, + /// indicating whether the reference is into the static or dynamic + /// table. The 6-bit prefix integer (Section 4.1.1) that follows is used + /// to locate the table entry for the field line. When T=1, the number + /// represents the static table index; when T=0, the number is the relative + /// index of the entry in the dynamic table. FieldSectionPrefix { require_insert_count: RequireInsertCount, signal: bool, @@ -268,7 +270,7 @@ pub(crate) enum Representation { }, } -//impl debug for Representation +// impl debug for Representation pub(crate) struct MidBit { //'N', indicates whether an intermediary is permitted to add this field line to the dynamic @@ -283,20 +285,18 @@ pub(crate) struct MidBit { pub(crate) struct PrefixMask(u8); impl PrefixMask { - pub(crate) const REQUIREINSERTCOUNT: Self = Self(0xff); - pub(crate) const DELTABASE: Self = Self(0x7f); + pub(crate) const REQUIRE_INSERT_COUNT: Self = Self(0xff); + pub(crate) const DELTA_BASE: Self = Self(0x7f); pub(crate) const INDEXED: Self = Self(0x3f); - pub(crate) const SETCAP: Self = Self(0x1f); - pub(crate) const INSERTWITHINDEX: Self = Self(0x3f); - pub(crate) const INSERTWITHLITERAL: Self = Self(0x1f); + pub(crate) const SET_CAP: Self = Self(0x1f); + pub(crate) const INSERT_WITH_INDEX: Self = Self(0x3f); + pub(crate) const INSERT_WITH_LITERAL: Self = Self(0x1f); pub(crate) const DUPLICATE: Self = Self(0x1f); - pub(crate) const ACK: Self = Self(0x7f); - pub(crate) const STREAMCANCEL: Self = Self(0x3f); - pub(crate) const INSERTCOUNTINCREMENT: Self = Self(0x3f); - - pub(crate) const INDEXINGWITHNAME: Self = Self(0x0f); - pub(crate) const INDEXINGWITHPOSTNAME: Self = Self(0x07); - pub(crate) const INDEXINGWITHLITERAL: Self = Self(0x07); - pub(crate) const INDEXEDWITHPOSTNAME: Self = Self(0x0f); + pub(crate) const STREAM_CANCEL: Self = Self(0x3f); + pub(crate) const INSERT_COUNT_INCREMENT: Self = Self(0x3f); + pub(crate) const INDEXING_WITH_NAME: Self = Self(0x0f); + pub(crate) const INDEXING_WITH_POST_NAME: Self = Self(0x07); + pub(crate) const INDEXING_WITH_LITERAL: Self = Self(0x07); + pub(crate) const INDEXED_WITH_POST_NAME: Self = Self(0x0f); } diff --git a/ylong_http/src/h3/qpack/table.rs b/ylong_http/src/h3/qpack/table.rs index faef2a9..c5f9717 100644 --- a/ylong_http/src/h3/qpack/table.rs +++ b/ylong_http/src/h3/qpack/table.rs @@ -11,22 +11,23 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![rustfmt::skip] - use std::collections::{HashMap, VecDeque}; +use crate::h3::qpack::error::QpackError; + /// The [`Dynamic Table`][dynamic_table] implementation of [QPACK]. /// /// [dynamic_table]: https://www.rfc-editor.org/rfc/rfc9204.html#name-dynamic-table /// [QPACK]: https://www.rfc-editor.org/rfc/rfc9204.html /// # Introduction -/// The dynamic table consists of a list of field lines maintained in first-in, first-out order. -/// A QPACK encoder and decoder share a dynamic table that is initially empty. -/// The encoder adds entries to the dynamic table and sends them to the decoder via instructions on -/// the encoder stream +/// The dynamic table consists of a list of field lines maintained in first-in, +/// first-out order. A QPACK encoder and decoder share a dynamic table that is +/// initially empty. The encoder adds entries to the dynamic table and sends +/// them to the decoder via instructions on the encoder stream /// -/// The dynamic table can contain duplicate entries (i.e., entries with the same name and same value). -/// Therefore, duplicate entries MUST NOT be treated as an error by the decoder. +/// The dynamic table can contain duplicate entries (i.e., entries with the same +/// name and same value). Therefore, duplicate entries MUST NOT be treated as an +/// error by the decoder. /// /// Dynamic table entries can have empty values. @@ -34,79 +35,98 @@ pub(crate) struct TableSearcher<'a> { dynamic: &'a DynamicTable, } +#[derive(Eq, PartialEq, Copy, Clone)] +pub(crate) enum SearchResult { + StaticIndex(usize), + StaticNameIndex(usize), + DynamicIndex(usize), + DynamicNameIndex(usize), + NotFound, +} + impl<'a> TableSearcher<'a> { pub(crate) fn new(dynamic: &'a DynamicTable) -> Self { Self { dynamic } } - /// Searches index in static and dynamic tables. - pub(crate) fn find_index_static(&self, header: &Field, value: &str) -> Option { - match StaticTable::index(header, value) { - x @ Some(TableIndex::Field(_)) => x, - _ => Some(TableIndex::None), - } + pub(crate) fn search_in_static(&self, header: &NameField, value: &str) -> TableIndex { + StaticTable::index(header, value) } - pub(crate) fn find_index_name_static(&self, header: &Field, value: &str) -> Option { - match StaticTable::index(header, value) { - x @ Some(TableIndex::FieldName(_)) => x, - _ => Some(TableIndex::None), - } - } - - pub(crate) fn find_index_dynamic(&self, header: &Field, value: &str) -> Option { - match self.dynamic.index(header, value) { - x @ Some(TableIndex::Field(_)) => x, - _ => Some(TableIndex::None), - } - } - - pub(crate) fn find_index_name_dynamic( + pub(crate) fn search_in_dynamic( &self, - header: &Field, + header: &NameField, value: &str, - ) -> Option { - match self.dynamic.index_name(header, value) { - x @ Some(TableIndex::FieldName(_)) => x, - _ => Some(TableIndex::None), - } + allow_block: bool, + ) -> TableIndex { + self.dynamic.index(header, value, allow_block) } - pub(crate) fn find_field_static(&self, index: usize) -> Option<(Field, String)> { + pub(crate) fn find_field_static(&self, index: usize) -> Option<(NameField, String)> { match StaticTable::field(index) { x @ Some((_, _)) => x, _ => None, } } - pub(crate) fn find_field_name_static(&self, index: usize) -> Option { + pub(crate) fn find_field_name_static(&self, index: usize) -> Option { StaticTable::field_name(index) } - pub(crate) fn find_field_dynamic(&self, index: usize) -> Option<(Field, String)> { + pub(crate) fn find_field_dynamic(&self, index: usize) -> Option<(NameField, String)> { self.dynamic.field(index) } - pub(crate) fn find_field_name_dynamic(&self, index: usize) -> Option { + pub(crate) fn find_field_name_dynamic(&self, index: usize) -> Option { self.dynamic.field_name(index) } } +#[derive(Clone, Eq, PartialEq)] +struct DynamicField { + index: u64, + name: NameField, + value: String, + tracked: usize, +} + +impl DynamicField { + pub(crate) fn index(&self) -> u64 { + self.index + } + + pub(crate) fn name(&self) -> &NameField { + &self.name + } + + pub(crate) fn value(&self) -> &str { + self.value.as_str() + } + + pub(crate) fn size(&self) -> usize { + self.name.len() + self.value.len() + 32 + } + + pub(crate) fn is_tracked(&self) -> bool { + self.tracked > 0 + } +} + pub struct DynamicTable { - queue: VecDeque<(Field, String)>, - // The size of the dynamic table is the sum of the size of its entries - size: usize, + queue: VecDeque, + // The used_cap of the dynamic table is the sum of the used_cap of its entries + used_cap: usize, capacity: usize, - pub(crate) insert_count: usize, + insert_count: usize, remove_count: usize, - pub(crate) known_received_count: usize, + known_received_count: u64, } impl DynamicTable { pub fn with_empty() -> Self { Self { queue: VecDeque::new(), - size: 0, + used_cap: 0, capacity: 0, insert_count: 0, remove_count: 0, @@ -114,51 +134,146 @@ impl DynamicTable { } } + pub(crate) fn update_capacity(&mut self, new_cap: usize) -> Option { + let mut updated = None; + if new_cap < self.capacity { + let required = self.capacity - new_cap; + if let Some(size) = self.can_evict(required) { + self.evict_drained(size); + self.capacity = new_cap; + updated = Some(new_cap); + } + } else { + self.capacity = new_cap; + updated = Some(new_cap); + } + updated + } + + pub(crate) fn insert_count(&self) -> usize { + self.insert_count + } + + pub(crate) fn track_field(&mut self, index: usize) { + if let Some(field) = self.queue.get_mut(index - self.remove_count) { + field.tracked += 1; + } else { + unreachable!() + } + } + + pub(crate) fn untracked_field(&mut self, index: usize) { + if let Some(field) = self.queue.get_mut(index - self.remove_count) { + assert!(field.tracked > 0); + field.tracked -= 1; + } else { + unreachable!() + } + } + + pub(crate) fn increase_known_receive_count(&mut self, increment: usize) { + self.known_received_count += (increment - 1) as u64; + // TODO 替换成error + assert!(self.known_received_count < (self.insert_count as u64)) + } + pub(crate) fn size(&self) -> usize { - self.size + self.used_cap } pub(crate) fn capacity(&self) -> usize { self.capacity } + pub(crate) fn can_evict(&mut self, required: usize) -> Option { + if required > self.capacity { + return None; + } + let bound = self.capacity - required; + let mut can_evict = 0; + let mut used_cap = self.used_cap; + while !self.queue.is_empty() && used_cap > bound { + if let Some(to_evict) = self.queue.front() { + if to_evict.is_tracked() || to_evict.index() > self.known_recved_count() { + return None; + } + used_cap -= to_evict.size(); + can_evict += 1; + } + } + Some(can_evict) + } + + // Note ensure that there are enough entries in the queue before the expulsion. + pub(crate) fn evict_drained(&mut self, size: usize) { + let mut to_evict = size; + while to_evict > 0 { + if let Some(field) = self.queue.pop_front() { + self.used_cap -= field.size(); + self.remove_count += 1; + } else { + unreachable!() + } + to_evict -= 1; + } + } + + pub(crate) fn known_recved_count(&self) -> u64 { + self.known_received_count + } + pub(crate) fn max_entries(&self) -> usize { self.capacity / 32 } + /// Updates `DynamicTable` by a given `Header` and value pair. - pub(crate) fn update(&mut self, field: Field, value: String) -> Option { + pub(crate) fn update(&mut self, field: NameField, value: String) -> usize { + let field = DynamicField { + index: self.insert_count() as u64, + name: field, + value, + tracked: 0, + }; + let size = field.size(); + let index = field.index(); + self.queue.push_back(field); self.insert_count += 1; - self.size += field.len() + value.len() + 32; - self.queue.push_back((field.clone(), value.clone())); - self.fit_size(); - self.index(&field, &value) + self.used_cap += size; + index as usize } - pub(crate) fn have_enough_space( - &self, - field: &Field, - value: &String, - insert_length: &usize, - ) -> bool { - if self.size + field.len() + value.len() + 32 <= self.capacity - insert_length { - return true; - } else { - let mut eviction_space = 0; - for (i, (h, v)) in self.queue.iter().enumerate() { - if i <= self.known_received_count { - eviction_space += h.len() + v.len() + 32; + /// Tries to get the index of a `Header`. + fn index(&self, header: &NameField, value: &str, allow_block: bool) -> TableIndex { + let mut index = TableIndex::None; + + for field in self.queue.iter() { + // 从queue的头开始迭代,index从小到大,找到最新(大)的index + if field.index() > self.known_recved_count() && !allow_block { + break; + } + if header == field.name() { + // find latest then return + index = if value == field.value() { + TableIndex::Field(field.index() as usize) } else { - if eviction_space - insert_length >= field.len() + value.len() + 32 { - return true; - } - return false; - } - if eviction_space - insert_length >= field.len() + value.len() + 32 { - return true; + TableIndex::FieldName(field.index() as usize) } } } - false + index + } + + pub(crate) fn field(&self, index: usize) -> Option<(NameField, String)> { + self.queue + .get(index - self.remove_count) + .cloned() + .map(|field| (field.name, field.value)) + } + + pub(crate) fn field_name(&self, index: usize) -> Option { + self.queue + .get(index - self.remove_count) + .map(|field| field.name().clone()) } /// Updates `DynamicTable`'s size. @@ -169,48 +284,15 @@ impl DynamicTable { /// Adjusts dynamic table content to fit its size. fn fit_size(&mut self) { - while self.size > self.capacity && !self.queue.is_empty() { - let (key, string) = self.queue.pop_front().unwrap(); + while self.used_cap > self.capacity && !self.queue.is_empty() { + let field = self.queue.pop_front().unwrap(); self.remove_count += 1; - self.capacity -= key.len() + string.len() + 32; + self.used_cap -= field.size(); } } - - /// Tries get the index of a `Header`. - fn index(&self, header: &Field, value: &str) -> Option { - // find latest - let mut index = None; - for (n, (h, v)) in self.queue.iter().enumerate() { - if let (true, true, _) = (header == h, value == v, &index) { - index = Some(TableIndex::Field(n + self.remove_count)) - } - } - index - } - - fn index_name(&self, header: &Field, value: &str) -> Option { - // find latest - let mut index = None; - for (n, (h, v)) in self.queue.iter().enumerate() { - if let (true, _, _) = (header == h, value == v, &index) { - index = Some(TableIndex::FieldName(n + self.remove_count)) - } - } - index - } - - pub(crate) fn field(&self, index: usize) -> Option<(Field, String)> { - self.queue.get(index - self.remove_count).cloned() - } - - pub(crate) fn field_name(&self, index: usize) -> Option { - self.queue - .get(index - self.remove_count) - .map(|(field, _)| field.clone()) - } } -#[derive(PartialEq, Clone)] +#[derive(PartialEq, Copy, Clone)] pub(crate) enum TableIndex { Field(usize), FieldName(usize), @@ -235,459 +317,479 @@ pub(crate) enum TableIndex { /// as a connection error of type QpackDecompressionFailed. /// If this index is received on the encoder stream, /// this MUST be treated as a connection error of type QpackEncoderStreamError. -/// struct StaticTable; impl StaticTable { /// Gets a `Field` by the given index. - fn field_name(index: usize) -> Option { + fn field_name(index: usize) -> Option { match index { - 0 => Some(Field::Authority), - 1 => Some(Field::Path), - 2 => Some(Field::Other(String::from("age"))), - 3 => Some(Field::Other(String::from("content-disposition"))), - 4 => Some(Field::Other(String::from("content-length"))), - 5 => Some(Field::Other(String::from("cookie"))), - 6 => Some(Field::Other(String::from("date"))), - 7 => Some(Field::Other(String::from("etag"))), - 8 => Some(Field::Other(String::from("if-modified-since"))), - 9 => Some(Field::Other(String::from("if-none-match"))), - 10 => Some(Field::Other(String::from("last-modified"))), - 11 => Some(Field::Other(String::from("link"))), - 12 => Some(Field::Other(String::from("location"))), - 13 => Some(Field::Other(String::from("referer"))), - 14 => Some(Field::Other(String::from("set-cookie"))), - 15..=21 => Some(Field::Method), - 22..=23 => Some(Field::Scheme), - 24..=28 => Some(Field::Status), - 29..=30 => Some(Field::Other(String::from("accept"))), - 31 => Some(Field::Other(String::from("accept-encoding"))), - 32 => Some(Field::Other(String::from("accept-ranges"))), - 33..=34 => Some(Field::Other(String::from("access-control-allow-headers"))), - 35 => Some(Field::Other(String::from("access-control-allow-origin"))), - 36..=41 => Some(Field::Other(String::from("cache-control"))), - 42..=43 => Some(Field::Other(String::from("content-encoding"))), - 44..=54 => Some(Field::Other(String::from("content-type"))), - 55 => Some(Field::Other(String::from("range"))), - 56..=58 => Some(Field::Other(String::from("strict-transport-security"))), - 59..=60 => Some(Field::Other(String::from("vary"))), - 61 => Some(Field::Other(String::from("x-content-type-options"))), - 62 => Some(Field::Other(String::from("x-xss-protection"))), - 63..=71 => Some(Field::Status), - 72 => Some(Field::Other(String::from("accept-language"))), - 73..=74 => Some(Field::Other(String::from( + 0 => Some(NameField::Authority), + 1 => Some(NameField::Path), + 2 => Some(NameField::Other(String::from("age"))), + 3 => Some(NameField::Other(String::from("content-disposition"))), + 4 => Some(NameField::Other(String::from("content-length"))), + 5 => Some(NameField::Other(String::from("cookie"))), + 6 => Some(NameField::Other(String::from("date"))), + 7 => Some(NameField::Other(String::from("etag"))), + 8 => Some(NameField::Other(String::from("if-modified-since"))), + 9 => Some(NameField::Other(String::from("if-none-match"))), + 10 => Some(NameField::Other(String::from("last-modified"))), + 11 => Some(NameField::Other(String::from("link"))), + 12 => Some(NameField::Other(String::from("location"))), + 13 => Some(NameField::Other(String::from("referer"))), + 14 => Some(NameField::Other(String::from("set-cookie"))), + 15..=21 => Some(NameField::Method), + 22..=23 => Some(NameField::Scheme), + 24..=28 => Some(NameField::Status), + 29..=30 => Some(NameField::Other(String::from("accept"))), + 31 => Some(NameField::Other(String::from("accept-encoding"))), + 32 => Some(NameField::Other(String::from("accept-ranges"))), + 33..=34 => Some(NameField::Other(String::from( + "access-control-allow-headers", + ))), + 35 => Some(NameField::Other(String::from( + "access-control-allow-origin", + ))), + 36..=41 => Some(NameField::Other(String::from("cache-control"))), + 42..=43 => Some(NameField::Other(String::from("content-encoding"))), + 44..=54 => Some(NameField::Other(String::from("content-type"))), + 55 => Some(NameField::Other(String::from("range"))), + 56..=58 => Some(NameField::Other(String::from("strict-transport-security"))), + 59..=60 => Some(NameField::Other(String::from("vary"))), + 61 => Some(NameField::Other(String::from("x-content-type-options"))), + 62 => Some(NameField::Other(String::from("x-xss-protection"))), + 63..=71 => Some(NameField::Status), + 72 => Some(NameField::Other(String::from("accept-language"))), + 73..=74 => Some(NameField::Other(String::from( "access-control-allow-credentials", ))), - 75 => Some(Field::Other(String::from("access-control-allow-headers"))), - 76..=78 => Some(Field::Other(String::from("access-control-allow-methods"))), - 79 => Some(Field::Other(String::from("access-control-expose-headers"))), - 80 => Some(Field::Other(String::from("access-control-request-headers"))), - 81..=82 => Some(Field::Other(String::from("access-control-request-method"))), - 83 => Some(Field::Other(String::from("alt-svc"))), - 84 => Some(Field::Other(String::from("authorization"))), - 85 => Some(Field::Other(String::from("content-security-policy"))), - 86 => Some(Field::Other(String::from("early-data"))), - 87 => Some(Field::Other(String::from("expect-ct"))), - 88 => Some(Field::Other(String::from("forwarded"))), - 89 => Some(Field::Other(String::from("if-range"))), - 90 => Some(Field::Other(String::from("origin"))), - 91 => Some(Field::Other(String::from("purpose"))), - 92 => Some(Field::Other(String::from("server"))), - 93 => Some(Field::Other(String::from("timing-allow-origin"))), - 94 => Some(Field::Other(String::from("upgrade-insecure-requests"))), - 95 => Some(Field::Other(String::from("user-agent"))), - 96 => Some(Field::Other(String::from("x-forwarded-for"))), - 97..=98 => Some(Field::Other(String::from("x-frame-options"))), + 75 => Some(NameField::Other(String::from( + "access-control-allow-headers", + ))), + 76..=78 => Some(NameField::Other(String::from( + "access-control-allow-methods", + ))), + 79 => Some(NameField::Other(String::from( + "access-control-expose-headers", + ))), + 80 => Some(NameField::Other(String::from( + "access-control-request-headers", + ))), + 81..=82 => Some(NameField::Other(String::from( + "access-control-request-method", + ))), + 83 => Some(NameField::Other(String::from("alt-svc"))), + 84 => Some(NameField::Other(String::from("authorization"))), + 85 => Some(NameField::Other(String::from("content-security-policy"))), + 86 => Some(NameField::Other(String::from("early-data"))), + 87 => Some(NameField::Other(String::from("expect-ct"))), + 88 => Some(NameField::Other(String::from("forwarded"))), + 89 => Some(NameField::Other(String::from("if-range"))), + 90 => Some(NameField::Other(String::from("origin"))), + 91 => Some(NameField::Other(String::from("purpose"))), + 92 => Some(NameField::Other(String::from("server"))), + 93 => Some(NameField::Other(String::from("timing-allow-origin"))), + 94 => Some(NameField::Other(String::from("upgrade-insecure-requests"))), + 95 => Some(NameField::Other(String::from("user-agent"))), + 96 => Some(NameField::Other(String::from("x-forwarded-for"))), + 97..=98 => Some(NameField::Other(String::from("x-frame-options"))), _ => None, } } /// Tries to get a `Field` and a value by the given index. - fn field(index: usize) -> Option<(Field, String)> { + fn field(index: usize) -> Option<(NameField, String)> { match index { - 1 => Some((Field::Path, String::from("/"))), - 2 => Some((Field::Other(String::from("age")), String::from("0"))), + 1 => Some((NameField::Path, String::from("/"))), + 2 => Some((NameField::Other(String::from("age")), String::from("0"))), 4 => Some(( - Field::Other(String::from("content-length")), + NameField::Other(String::from("content-length")), String::from("0"), )), - 15 => Some((Field::Method, String::from("CONNECT"))), - 16 => Some((Field::Method, String::from("DELETE"))), - 17 => Some((Field::Method, String::from("GET"))), - 18 => Some((Field::Method, String::from("HEAD"))), - 19 => Some((Field::Method, String::from("OPTIONS"))), - 20 => Some((Field::Method, String::from("POST"))), - 21 => Some((Field::Method, String::from("PUT"))), - 22 => Some((Field::Scheme, String::from("http"))), - 23 => Some((Field::Scheme, String::from("https"))), - 24 => Some((Field::Status, String::from("103"))), - 25 => Some((Field::Status, String::from("200"))), - 26 => Some((Field::Status, String::from("304"))), - 27 => Some((Field::Status, String::from("404"))), - 28 => Some((Field::Status, String::from("503"))), - 29 => Some((Field::Other(String::from("accept")), String::from("*/*"))), + 15 => Some((NameField::Method, String::from("CONNECT"))), + 16 => Some((NameField::Method, String::from("DELETE"))), + 17 => Some((NameField::Method, String::from("GET"))), + 18 => Some((NameField::Method, String::from("HEAD"))), + 19 => Some((NameField::Method, String::from("OPTIONS"))), + 20 => Some((NameField::Method, String::from("POST"))), + 21 => Some((NameField::Method, String::from("PUT"))), + 22 => Some((NameField::Scheme, String::from("http"))), + 23 => Some((NameField::Scheme, String::from("https"))), + 24 => Some((NameField::Status, String::from("103"))), + 25 => Some((NameField::Status, String::from("200"))), + 26 => Some((NameField::Status, String::from("304"))), + 27 => Some((NameField::Status, String::from("404"))), + 28 => Some((NameField::Status, String::from("503"))), + 29 => Some(( + NameField::Other(String::from("accept")), + String::from("*/*"), + )), 30 => Some(( - Field::Other(String::from("accept")), + NameField::Other(String::from("accept")), String::from("application/dns-message"), )), 31 => Some(( - Field::Other(String::from("accept-encoding")), + NameField::Other(String::from("accept-encoding")), String::from("gzip, deflate, br"), )), 32 => Some(( - Field::Other(String::from("accept-ranges")), + NameField::Other(String::from("accept-ranges")), String::from("bytes"), )), 33 => Some(( - Field::Other(String::from("access-control-allow-headers")), + NameField::Other(String::from("access-control-allow-headers")), String::from("cache-control"), )), 34 => Some(( - Field::Other(String::from("access-control-allow-headers")), + NameField::Other(String::from("access-control-allow-headers")), String::from("content-type"), )), 35 => Some(( - Field::Other(String::from("access-control-allow-origin")), + NameField::Other(String::from("access-control-allow-origin")), String::from("*"), )), 36 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("max-age=0"), )), 37 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("max-age=2592000"), )), 38 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("max-age=604800"), )), 39 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("no-cache"), )), 40 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("no-store"), )), 41 => Some(( - Field::Other(String::from("cache-control")), + NameField::Other(String::from("cache-control")), String::from("public, max-age=31536000"), )), 42 => Some(( - Field::Other(String::from("content-encoding")), + NameField::Other(String::from("content-encoding")), String::from("br"), )), 43 => Some(( - Field::Other(String::from("content-encoding")), + NameField::Other(String::from("content-encoding")), String::from("gzip"), )), 44 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("application/dns-message"), )), 45 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("application/javascript"), )), 46 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("application/json"), )), 47 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("application/x-www-form-urlencoded"), )), 48 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("image/gif"), )), 49 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("image/jpeg"), )), 50 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("image/png"), )), 51 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("text/css"), )), 52 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("text/html; charset=utf-8"), )), 53 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("text/plain"), )), 54 => Some(( - Field::Other(String::from("content-type")), + NameField::Other(String::from("content-type")), String::from("text/plain;charset=utf-8"), )), 55 => Some(( - Field::Other(String::from("range")), + NameField::Other(String::from("range")), String::from("bytes=0-"), )), 56 => Some(( - Field::Other(String::from("strict-transport-security")), + NameField::Other(String::from("strict-transport-security")), String::from("max-age=31536000"), )), 57 => Some(( - Field::Other(String::from("strict-transport-security")), + NameField::Other(String::from("strict-transport-security")), String::from("max-age=31536000; includesubdomains"), )), 58 => Some(( - Field::Other(String::from("strict-transport-security")), + NameField::Other(String::from("strict-transport-security")), String::from("max-age=31536000; includesubdomains; preload"), )), 59 => Some(( - Field::Other(String::from("vary")), + NameField::Other(String::from("vary")), String::from("accept-encoding"), )), - 60 => Some((Field::Other(String::from("vary")), String::from("origin"))), + 60 => Some(( + NameField::Other(String::from("vary")), + String::from("origin"), + )), 61 => Some(( - Field::Other(String::from("x-content-type-options")), + NameField::Other(String::from("x-content-type-options")), String::from("nosniff"), )), 62 => Some(( - Field::Other(String::from("x-xss-protection")), + NameField::Other(String::from("x-xss-protection")), String::from("1; mode=block"), )), - 63 => Some((Field::Status, String::from("100"))), - 64 => Some((Field::Status, String::from("204"))), - 65 => Some((Field::Status, String::from("206"))), - 66 => Some((Field::Status, String::from("302"))), - 67 => Some((Field::Status, String::from("400"))), - 68 => Some((Field::Status, String::from("403"))), - 69 => Some((Field::Status, String::from("421"))), - 70 => Some((Field::Status, String::from("425"))), - 71 => Some((Field::Status, String::from("500"))), + 63 => Some((NameField::Status, String::from("100"))), + 64 => Some((NameField::Status, String::from("204"))), + 65 => Some((NameField::Status, String::from("206"))), + 66 => Some((NameField::Status, String::from("302"))), + 67 => Some((NameField::Status, String::from("400"))), + 68 => Some((NameField::Status, String::from("403"))), + 69 => Some((NameField::Status, String::from("421"))), + 70 => Some((NameField::Status, String::from("425"))), + 71 => Some((NameField::Status, String::from("500"))), 73 => Some(( - Field::Other(String::from("access-control-allow-credentials")), + NameField::Other(String::from("access-control-allow-credentials")), String::from("FALSE"), )), 74 => Some(( - Field::Other(String::from("access-control-allow-credentials")), + NameField::Other(String::from("access-control-allow-credentials")), String::from("TRUE"), )), 75 => Some(( - Field::Other(String::from("access-control-allow-headers")), + NameField::Other(String::from("access-control-allow-headers")), String::from("*"), )), 76 => Some(( - Field::Other(String::from("access-control-allow-methods")), + NameField::Other(String::from("access-control-allow-methods")), String::from("get"), )), 77 => Some(( - Field::Other(String::from("access-control-allow-methods")), + NameField::Other(String::from("access-control-allow-methods")), String::from("get, post, options"), )), 78 => Some(( - Field::Other(String::from("access-control-allow-methods")), + NameField::Other(String::from("access-control-allow-methods")), String::from("options"), )), 79 => Some(( - Field::Other(String::from("access-control-expose-headers")), + NameField::Other(String::from("access-control-expose-headers")), String::from("content-length"), )), 80 => Some(( - Field::Other(String::from("access-control-request-headers")), + NameField::Other(String::from("access-control-request-headers")), String::from("content-type"), )), 81 => Some(( - Field::Other(String::from("access-control-request-method")), + NameField::Other(String::from("access-control-request-method")), String::from("get"), )), 82 => Some(( - Field::Other(String::from("access-control-request-method")), + NameField::Other(String::from("access-control-request-method")), String::from("post"), )), - 83 => Some((Field::Other(String::from("alt-svc")), String::from("clear"))), + 83 => Some(( + NameField::Other(String::from("alt-svc")), + String::from("clear"), + )), 85 => Some(( - Field::Other(String::from("content-security-policy")), + NameField::Other(String::from("content-security-policy")), String::from("script-src 'none'; object-src 'none'; base-uri 'none'"), )), - 86 => Some((Field::Other(String::from("early-data")), String::from("1"))), + 86 => Some(( + NameField::Other(String::from("early-data")), + String::from("1"), + )), 91 => Some(( - Field::Other(String::from("purpose")), + NameField::Other(String::from("purpose")), String::from("prefetch"), )), 93 => Some(( - Field::Other(String::from("timing-allow-origin")), + NameField::Other(String::from("timing-allow-origin")), String::from("*"), )), 94 => Some(( - Field::Other(String::from("upgrade-insecure-requests")), + NameField::Other(String::from("upgrade-insecure-requests")), String::from("1"), )), 97 => Some(( - Field::Other(String::from("x-frame-options")), + NameField::Other(String::from("x-frame-options")), String::from("deny"), )), 98 => Some(( - Field::Other(String::from("x-frame-options")), + NameField::Other(String::from("x-frame-options")), String::from("sameorigin"), )), _ => None, } } - /// Tries to get a `Index` by the given field and value. - fn index(field: &Field, value: &str) -> Option { + fn index(field: &NameField, value: &str) -> TableIndex { match (field, value) { - (Field::Authority, _) => Some(TableIndex::FieldName(0)), - (Field::Path, "/") => Some(TableIndex::Field(1)), - (Field::Path, _) => Some(TableIndex::FieldName(1)), - (Field::Method, "CONNECT") => Some(TableIndex::Field(15)), - (Field::Method, "DELETE") => Some(TableIndex::Field(16)), - (Field::Method, "GET") => Some(TableIndex::Field(17)), - (Field::Method, "HEAD") => Some(TableIndex::Field(18)), - (Field::Method, "OPTIONS") => Some(TableIndex::Field(19)), - (Field::Method, "POST") => Some(TableIndex::Field(20)), - (Field::Method, "PUT") => Some(TableIndex::Field(21)), - (Field::Method, _) => Some(TableIndex::FieldName(15)), - (Field::Scheme, "http") => Some(TableIndex::Field(22)), - (Field::Scheme, "https") => Some(TableIndex::Field(23)), - (Field::Scheme, _) => Some(TableIndex::FieldName(22)), - (Field::Status, "103") => Some(TableIndex::Field(24)), - (Field::Status, "200") => Some(TableIndex::Field(25)), - (Field::Status, "304") => Some(TableIndex::Field(26)), - (Field::Status, "404") => Some(TableIndex::Field(27)), - (Field::Status, "503") => Some(TableIndex::Field(28)), - (Field::Status, "100") => Some(TableIndex::Field(63)), - (Field::Status, "204") => Some(TableIndex::Field(64)), - (Field::Status, "206") => Some(TableIndex::Field(65)), - (Field::Status, "302") => Some(TableIndex::Field(66)), - (Field::Status, "400") => Some(TableIndex::Field(67)), - (Field::Status, "403") => Some(TableIndex::Field(68)), - (Field::Status, "421") => Some(TableIndex::Field(69)), - (Field::Status, "425") => Some(TableIndex::Field(70)), - (Field::Status, "500") => Some(TableIndex::Field(71)), - (Field::Status, _) => Some(TableIndex::FieldName(24)), - (Field::Other(s), v) => match (s.as_str(), v) { - ("age", "0") => Some(TableIndex::Field(2)), - ("age", _) => Some(TableIndex::FieldName(2)), - ("content-disposition", _) => Some(TableIndex::FieldName(3)), - ("content-length", "0") => Some(TableIndex::Field(4)), - ("content-length", _) => Some(TableIndex::FieldName(4)), - ("cookie", _) => Some(TableIndex::FieldName(5)), - ("date", _) => Some(TableIndex::FieldName(6)), - ("etag", _) => Some(TableIndex::FieldName(7)), - ("if-modified-since", _) => Some(TableIndex::FieldName(8)), - ("if-none-match", _) => Some(TableIndex::FieldName(9)), - ("last-modified", _) => Some(TableIndex::FieldName(10)), - ("link", _) => Some(TableIndex::FieldName(11)), - ("location", _) => Some(TableIndex::FieldName(12)), - ("referer", _) => Some(TableIndex::FieldName(13)), - ("set-cookie", _) => Some(TableIndex::FieldName(14)), - ("accept", "*/*") => Some(TableIndex::Field(29)), - ("accept", "application/dns-message") => Some(TableIndex::Field(30)), - ("accept", _) => Some(TableIndex::FieldName(29)), - ("accept-encoding", "gzip, deflate, br") => Some(TableIndex::Field(31)), - ("accept-encoding", _) => Some(TableIndex::FieldName(31)), - ("accept-ranges", "bytes") => Some(TableIndex::Field(32)), - ("accept-ranges", _) => Some(TableIndex::FieldName(32)), - ("access-control-allow-headers", "cache-control") => Some(TableIndex::Field(33)), - ("access-control-allow-headers", "content-type") => Some(TableIndex::Field(34)), - ("access-control-allow-origin", "*") => Some(TableIndex::Field(35)), - ("access-control-allow-origin", _) => Some(TableIndex::FieldName(35)), - ("cache-control", "max-age=0") => Some(TableIndex::Field(36)), - ("cache-control", "max-age=2592000") => Some(TableIndex::Field(37)), - ("cache-control", "max-age=604800") => Some(TableIndex::Field(38)), - ("cache-control", "no-cache") => Some(TableIndex::Field(39)), - ("cache-control", "no-store") => Some(TableIndex::Field(40)), - ("cache-control", "public, max-age=31536000") => Some(TableIndex::Field(41)), - ("cache-control", _) => Some(TableIndex::FieldName(36)), - ("content-encoding", "br") => Some(TableIndex::Field(42)), - ("content-encoding", "gzip") => Some(TableIndex::Field(43)), - ("content-encoding", _) => Some(TableIndex::FieldName(42)), - ("content-type", "application/dns-message") => Some(TableIndex::Field(44)), - ("content-type", "application/javascript") => Some(TableIndex::Field(45)), - ("content-type", "application/json") => Some(TableIndex::Field(46)), - ("content-type", "application/x-www-form-urlencoded") => { - Some(TableIndex::Field(47)) - } - ("content-type", "image/gif") => Some(TableIndex::Field(48)), - ("content-type", "image/jpeg") => Some(TableIndex::Field(49)), - ("content-type", "image/png") => Some(TableIndex::Field(50)), - ("content-type", "text/css") => Some(TableIndex::Field(51)), - ("content-type", "text/html; charset=utf-8") => Some(TableIndex::Field(52)), - ("content-type", "text/plain") => Some(TableIndex::Field(53)), - ("content-type", "text/plain;charset=utf-8") => Some(TableIndex::Field(54)), - ("content-type", _) => Some(TableIndex::FieldName(44)), - ("range", "bytes=0-") => Some(TableIndex::Field(55)), - ("range", _) => Some(TableIndex::FieldName(55)), - ("strict-transport-security", "max-age=31536000") => Some(TableIndex::Field(56)), + (NameField::Authority, _) => TableIndex::FieldName(0), + (NameField::Path, "/") => TableIndex::Field(1), + (NameField::Path, _) => TableIndex::FieldName(1), + (NameField::Method, "CONNECT") => TableIndex::Field(15), + (NameField::Method, "DELETE") => TableIndex::Field(16), + (NameField::Method, "GET") => TableIndex::Field(17), + (NameField::Method, "HEAD") => TableIndex::Field(18), + (NameField::Method, "OPTIONS") => TableIndex::Field(19), + (NameField::Method, "POST") => TableIndex::Field(20), + (NameField::Method, "PUT") => TableIndex::Field(21), + (NameField::Method, _) => TableIndex::FieldName(15), + (NameField::Scheme, "http") => TableIndex::Field(22), + (NameField::Scheme, "https") => TableIndex::Field(23), + (NameField::Scheme, _) => TableIndex::FieldName(22), + (NameField::Status, "103") => TableIndex::Field(24), + (NameField::Status, "200") => TableIndex::Field(25), + (NameField::Status, "304") => TableIndex::Field(26), + (NameField::Status, "404") => TableIndex::Field(27), + (NameField::Status, "503") => TableIndex::Field(28), + (NameField::Status, "100") => TableIndex::Field(63), + (NameField::Status, "204") => TableIndex::Field(64), + (NameField::Status, "206") => TableIndex::Field(65), + (NameField::Status, "302") => TableIndex::Field(66), + (NameField::Status, "400") => TableIndex::Field(67), + (NameField::Status, "403") => TableIndex::Field(68), + (NameField::Status, "421") => TableIndex::Field(69), + (NameField::Status, "425") => TableIndex::Field(70), + (NameField::Status, "500") => TableIndex::Field(71), + (NameField::Status, _) => TableIndex::FieldName(24), + (NameField::Other(s), v) => match (s.as_str(), v) { + ("age", "0") => TableIndex::Field(2), + ("age", _) => TableIndex::FieldName(2), + ("content-disposition", _) => TableIndex::FieldName(3), + ("content-length", "0") => TableIndex::Field(4), + ("content-length", _) => TableIndex::FieldName(4), + ("cookie", _) => TableIndex::FieldName(5), + ("date", _) => TableIndex::FieldName(6), + ("etag", _) => TableIndex::FieldName(7), + ("if-modified-since", _) => TableIndex::FieldName(8), + ("if-none-match", _) => TableIndex::FieldName(9), + ("last-modified", _) => TableIndex::FieldName(10), + ("link", _) => TableIndex::FieldName(11), + ("location", _) => TableIndex::FieldName(12), + ("referer", _) => TableIndex::FieldName(13), + ("set-cookie", _) => TableIndex::FieldName(14), + ("accept", "*/*") => TableIndex::Field(29), + ("accept", "application/dns-message") => TableIndex::Field(30), + ("accept", _) => TableIndex::FieldName(29), + ("accept-encoding", "gzip, deflate, br") => TableIndex::Field(31), + ("accept-encoding", _) => TableIndex::FieldName(31), + ("accept-ranges", "bytes") => TableIndex::Field(32), + ("accept-ranges", _) => TableIndex::FieldName(32), + ("access-control-allow-headers", "cache-control") => TableIndex::Field(33), + ("access-control-allow-headers", "content-type") => TableIndex::Field(34), + ("access-control-allow-origin", "*") => TableIndex::Field(35), + ("access-control-allow-origin", _) => TableIndex::FieldName(35), + ("cache-control", "max-age=0") => TableIndex::Field(36), + ("cache-control", "max-age=2592000") => TableIndex::Field(37), + ("cache-control", "max-age=604800") => TableIndex::Field(38), + ("cache-control", "no-cache") => TableIndex::Field(39), + ("cache-control", "no-store") => TableIndex::Field(40), + ("cache-control", "public, max-age=31536000") => TableIndex::Field(41), + ("cache-control", _) => TableIndex::FieldName(36), + ("content-encoding", "br") => TableIndex::Field(42), + ("content-encoding", "gzip") => TableIndex::Field(43), + ("content-encoding", _) => TableIndex::FieldName(42), + ("content-type", "application/dns-message") => TableIndex::Field(44), + ("content-type", "application/javascript") => TableIndex::Field(45), + ("content-type", "application/json") => TableIndex::Field(46), + ("content-type", "application/x-www-form-urlencoded") => TableIndex::Field(47), + ("content-type", "image/gif") => TableIndex::Field(48), + ("content-type", "image/jpeg") => TableIndex::Field(49), + ("content-type", "image/png") => TableIndex::Field(50), + ("content-type", "text/css") => TableIndex::Field(51), + ("content-type", "text/html; charset=utf-8") => TableIndex::Field(52), + ("content-type", "text/plain") => TableIndex::Field(53), + ("content-type", "text/plain;charset=utf-8") => TableIndex::Field(54), + ("content-type", _) => TableIndex::FieldName(44), + ("range", "bytes=0-") => TableIndex::Field(55), + ("range", _) => TableIndex::FieldName(55), + ("strict-transport-security", "max-age=31536000") => TableIndex::Field(56), ("strict-transport-security", "max-age=31536000; includesubdomains") => { - Some(TableIndex::Field(57)) + TableIndex::Field(57) } ("strict-transport-security", "max-age=31536000; includesubdomains; preload") => { - Some(TableIndex::Field(58)) - } - ("strict-transport-security", _) => Some(TableIndex::FieldName(56)), - ("vary", "accept-encoding") => Some(TableIndex::Field(59)), - ("vary", "origin") => Some(TableIndex::Field(60)), - ("vary", _) => Some(TableIndex::FieldName(59)), - ("x-content-type-options", "nosniff") => Some(TableIndex::Field(61)), - ("x-content-type-options", _) => Some(TableIndex::FieldName(61)), - ("x-xss-protection", "1; mode=block") => Some(TableIndex::Field(62)), - ("x-xss-protection", _) => Some(TableIndex::FieldName(62)), - ("accept-language", _) => Some(TableIndex::FieldName(72)), - ("access-control-allow-credentials", "FALSE") => Some(TableIndex::Field(73)), - ("access-control-allow-credentials", "TRUE") => Some(TableIndex::Field(74)), - ("access-control-allow-credentials", _) => Some(TableIndex::FieldName(73)), - ("access-control-allow-headers", "*") => Some(TableIndex::Field(75)), - ("access-control-allow-headers", _) => Some(TableIndex::FieldName(75)), - ("access-control-allow-methods", "get") => Some(TableIndex::Field(76)), - ("access-control-allow-methods", "get, post, options") => { - Some(TableIndex::Field(77)) + TableIndex::Field(58) } - ("access-control-allow-methods", "options") => Some(TableIndex::Field(78)), - ("access-control-allow-methods", _) => Some(TableIndex::FieldName(76)), - ("access-control-expose-headers", "content-length") => Some(TableIndex::Field(79)), - ("access-control-expose-headers", _) => Some(TableIndex::FieldName(79)), - ("access-control-request-headers", "content-type") => Some(TableIndex::Field(80)), - ("access-control-request-headers", _) => Some(TableIndex::FieldName(80)), - ("access-control-request-method", "get") => Some(TableIndex::Field(81)), - ("access-control-request-method", "post") => Some(TableIndex::Field(82)), - ("access-control-request-method", _) => Some(TableIndex::FieldName(81)), - ("alt-svc", "clear") => Some(TableIndex::Field(83)), - ("alt-svc", _) => Some(TableIndex::FieldName(83)), - ("authorization", _) => Some(TableIndex::FieldName(84)), + ("strict-transport-security", _) => TableIndex::FieldName(56), + ("vary", "accept-encoding") => TableIndex::Field(59), + ("vary", "origin") => TableIndex::Field(60), + ("vary", _) => TableIndex::FieldName(59), + ("x-content-type-options", "nosniff") => TableIndex::Field(61), + ("x-content-type-options", _) => TableIndex::FieldName(61), + ("x-xss-protection", "1; mode=block") => TableIndex::Field(62), + ("x-xss-protection", _) => TableIndex::FieldName(62), + ("accept-language", _) => TableIndex::FieldName(72), + ("access-control-allow-credentials", "FALSE") => TableIndex::Field(73), + ("access-control-allow-credentials", "TRUE") => TableIndex::Field(74), + ("access-control-allow-credentials", _) => TableIndex::FieldName(73), + ("access-control-allow-headers", "*") => TableIndex::Field(75), + ("access-control-allow-headers", _) => TableIndex::FieldName(75), + ("access-control-allow-methods", "get") => TableIndex::Field(76), + ("access-control-allow-methods", "get, post, options") => TableIndex::Field(77), + ("access-control-allow-methods", "options") => TableIndex::Field(78), + ("access-control-allow-methods", _) => TableIndex::FieldName(76), + ("access-control-expose-headers", "content-length") => TableIndex::Field(79), + ("access-control-expose-headers", _) => TableIndex::FieldName(79), + ("access-control-request-headers", "content-type") => TableIndex::Field(80), + ("access-control-request-headers", _) => TableIndex::FieldName(80), + ("access-control-request-method", "get") => TableIndex::Field(81), + ("access-control-request-method", "post") => TableIndex::Field(82), + ("access-control-request-method", _) => TableIndex::FieldName(81), + ("alt-svc", "clear") => TableIndex::Field(83), + ("alt-svc", _) => TableIndex::FieldName(83), + ("authorization", _) => TableIndex::FieldName(84), ( "content-security-policy", "script-src 'none'; object-src 'none'; base-uri 'none'", - ) => Some(TableIndex::Field(85)), - ("content-security-policy", _) => Some(TableIndex::FieldName(85)), - ("early-data", "1") => Some(TableIndex::Field(86)), - ("early-data", _) => Some(TableIndex::FieldName(86)), - ("expect-ct", _) => Some(TableIndex::FieldName(87)), - ("forwarded", _) => Some(TableIndex::FieldName(88)), - ("if-range", _) => Some(TableIndex::FieldName(89)), - ("origin", _) => Some(TableIndex::FieldName(90)), - ("purpose", "prefetch") => Some(TableIndex::Field(91)), - ("purpose", _) => Some(TableIndex::FieldName(91)), - ("server", _) => Some(TableIndex::FieldName(92)), - ("timing-allow-origin", "*") => Some(TableIndex::Field(93)), - ("timing-allow-origin", _) => Some(TableIndex::FieldName(93)), - ("upgrade-insecure-requests", "1") => Some(TableIndex::Field(94)), - ("upgrade-insecure-requests", _) => Some(TableIndex::FieldName(94)), - ("user-agent", _) => Some(TableIndex::FieldName(95)), - ("x-forwarded-for", _) => Some(TableIndex::FieldName(96)), - ("x-frame-options", "deny") => Some(TableIndex::Field(97)), - ("x-frame-options", "sameorigin") => Some(TableIndex::Field(98)), - ("x-frame-options", _) => Some(TableIndex::FieldName(97)), - _ => None, + ) => TableIndex::Field(85), + ("content-security-policy", _) => TableIndex::FieldName(85), + ("early-data", "1") => TableIndex::Field(86), + ("early-data", _) => TableIndex::FieldName(86), + ("expect-ct", _) => TableIndex::FieldName(87), + ("forwarded", _) => TableIndex::FieldName(88), + ("if-range", _) => TableIndex::FieldName(89), + ("origin", _) => TableIndex::FieldName(90), + ("purpose", "prefetch") => TableIndex::Field(91), + ("purpose", _) => TableIndex::FieldName(91), + ("server", _) => TableIndex::FieldName(92), + ("timing-allow-origin", "*") => TableIndex::Field(93), + ("timing-allow-origin", _) => TableIndex::FieldName(93), + ("upgrade-insecure-requests", "1") => TableIndex::Field(94), + ("upgrade-insecure-requests", _) => TableIndex::FieldName(94), + ("user-agent", _) => TableIndex::FieldName(95), + ("x-forwarded-for", _) => TableIndex::FieldName(96), + ("x-frame-options", "deny") => TableIndex::Field(97), + ("x-frame-options", "sameorigin") => TableIndex::Field(98), + ("x-frame-options", _) => TableIndex::FieldName(97), + _ => TableIndex::None, }, } } } #[derive(Clone, PartialEq, Eq, Debug)] -pub enum Field { +pub enum NameField { Authority, Method, Path, @@ -696,26 +798,28 @@ pub enum Field { Other(String), } -impl Field { +impl NameField { pub(crate) fn len(&self) -> usize { match self { - Field::Authority => 10, // 10 is the length of ":authority". - Field::Method => 7, // 7 is the length of ":method". - Field::Path => 5, // 5 is the length of ":path". - Field::Scheme => 7, // 7 is the length of "scheme". - Field::Status => 7, // 7 is the length of "status". - Field::Other(s) => s.len(), + NameField::Authority => 10, // 10 is the length of ":authority". + NameField::Method => 7, // 7 is the length of ":method". + NameField::Path => 5, // 5 is the length of ":path". + NameField::Scheme => 7, // 7 is the length of "scheme". + NameField::Status => 7, // 7 is the length of "status". + NameField::Other(s) => s.len(), } } +} - pub(crate) fn into_string(self) -> String { +impl ToString for NameField { + fn to_string(&self) -> String { match self { - Field::Authority => String::from(":authority"), - Field::Method => String::from(":method"), - Field::Path => String::from(":path"), - Field::Scheme => String::from(":scheme"), - Field::Status => String::from(":status"), - Field::Other(s) => s, + NameField::Authority => String::from(":authority"), + NameField::Method => String::from(":method"), + NameField::Path => String::from(":path"), + NameField::Scheme => String::from(":scheme"), + NameField::Status => String::from(":status"), + NameField::Other(s) => s.clone(), } } } diff --git a/ylong_http/src/h3/stream.rs b/ylong_http/src/h3/stream.rs index 0d03a7e..b050be1 100644 --- a/ylong_http/src/h3/stream.rs +++ b/ylong_http/src/h3/stream.rs @@ -10,3 +10,128 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +use std::ops::{Deref, DerefMut}; + +use crate::h3::Frame; + +/// HTTP3 control stream type code. +pub const CONTROL_STREAM_TYPE: u8 = 0x0; +/// HTTP3 push stream type code. +pub const PUSH_STREAM_TYPE: u8 = 0x1; +/// qpack encoder stream type code. +pub const QPACK_ENCODER_STREAM_TYPE: u8 = 0x2; +/// qpack decoder stream type code. +pub const QPACK_DECODER_STREAM_TYPE: u8 = 0x3; + +/// Http3 decoded frames. +pub struct Frames { + list: Vec, +} + +/// An iterator of `Frames`. +pub struct FramesIter<'a> { + iter: core::slice::Iter<'a, FrameKind>, +} + +/// A consuming iterator of `Frames`. +pub struct FramesIntoIter { + into_iter: std::vec::IntoIter, +} + +/// Qpack decoder the state after decoding a frame. +pub enum FrameKind { + /// PUSH_PROMISE or HEADERS frame parsing completed. + Complete(Box), + /// Partial decoded of PUSH_PROMISE or HEADERS frame. + Partial, + /// Headers part is blocked at Qpack decode. + Blocked, +} + +impl Frames { + /// Gets an iterator for fFrames. + pub fn iter(&self) -> FramesIter { + FramesIter { + iter: self.list.iter(), + } + } + + /// Returns the size of `Frames`. + pub fn len(&self) -> usize { + self.list.len() + } + + /// Checks if the `Frames` is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Frames { + pub(crate) fn new() -> Self { + Frames { list: vec![] } + } + pub(crate) fn push(&mut self, frame: FrameKind) { + self.list.push(frame) + } +} + +impl<'a> Deref for FramesIter<'a> { + type Target = core::slice::Iter<'a, FrameKind>; + + fn deref(&self) -> &Self::Target { + &self.iter + } +} + +impl<'a> DerefMut for FramesIter<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.iter + } +} + +impl<'a> Iterator for FramesIter<'a> { + type Item = &'a FrameKind; + + fn next(&mut self) -> Option { + self.iter.next() + } +} + +impl Iterator for FramesIntoIter { + type Item = FrameKind; + + fn next(&mut self) -> Option { + self.into_iter.next() + } +} + +impl core::iter::IntoIterator for Frames { + type Item = FrameKind; + type IntoIter = FramesIntoIter; + + fn into_iter(self) -> Self::IntoIter { + FramesIntoIter { + into_iter: self.list.into_iter(), + } + } +} + +/// The Http3 decoder deserializes the data. +pub enum StreamMessage { + /// Request stream message. + Request(Frames), + /// Control stream message. + Control(Frames), + /// Push stream message. + Push(u64, Frames), + /// Qpack encoder stream message. + QpackEncoder(Vec), + /// Qpack decoder stream message. + QpackDecoder(Vec), + /// Unknown stream message. + Unknown, + /// Bytes too short to decode a stream. + WaitingMore, +} diff --git a/ylong_http/src/request/uri/mod.rs b/ylong_http/src/request/uri/mod.rs index 78818e2..b3c716f 100644 --- a/ylong_http/src/request/uri/mod.rs +++ b/ylong_http/src/request/uri/mod.rs @@ -775,6 +775,30 @@ impl Host { } } +impl core::str::FromStr for Host { + type Err = HttpError; + + /// Constructs host from a string slice. + /// + /// # Examples + /// + /// ``` + /// use std::str::FromStr; + /// + /// use ylong_http::request::uri::Host; + /// + /// let host = Host::from_str("www.example.com").unwrap(); + /// assert_eq!(host.as_str(), "www.example.com"); + /// ``` + fn from_str(host: &str) -> Result { + if host.is_empty() { + Err(InvalidUri::UriMissHost.into()) + } else { + Ok(Self(String::from(host))) + } + } +} + impl ToString for Host { fn to_string(&self) -> String { self.0.to_owned() @@ -803,7 +827,7 @@ impl Port { self.0.as_str() } - /// Returns a u16 value of the `Port`. + /// Returns an u16 value of the `Port`. /// /// # Examples /// @@ -821,6 +845,27 @@ impl Port { } } +impl core::str::FromStr for Port { + type Err = HttpError; + + /// Constructs host from a string slice. + /// + /// # Examples + /// + /// ``` + /// use std::str::FromStr; + /// + /// use ylong_http::request::uri::Port; + /// + /// let host = Port::from_str("80").unwrap(); + /// assert_eq!(host.as_str(), "80"); + /// ``` + fn from_str(port: &str) -> Result { + port.parse::().map_err(|_| InvalidUri::InvalidPort)?; + Ok(Self(String::from(port))) + } +} + /// Path component of [`Uri`]. /// /// [`Uri`]: Uri diff --git a/ylong_http/src/response/status.rs b/ylong_http/src/response/status.rs index ea25434..c5ad299 100644 --- a/ylong_http/src/response/status.rs +++ b/ylong_http/src/response/status.rs @@ -281,7 +281,6 @@ impl Display for StatusCode { } } -#[rustfmt::skip] status_list!( /// [`100 Continue`]: https://tools.ietf.org/html/rfc7231#section-6.2.1 (100, CONTINUE, "Continue"), diff --git a/ylong_http_client/Cargo.toml b/ylong_http_client/Cargo.toml index ba9babf..2dad5a0 100644 --- a/ylong_http_client/Cargo.toml +++ b/ylong_http_client/Cargo.toml @@ -9,6 +9,7 @@ keywords = ["ylong", "http", "client"] [dependencies] ylong_http = { path = "../ylong_http" } +quiche = { version = "0.22.0", features = ["ffi"], optional = true } libc = { version = "0.2.134", optional = true } tokio = { version = "1.20.1", features = ["io-util", "net", "rt", "rt-multi-thread", "macros", "sync", "time"], optional = true } ylong_runtime = { git = "https://gitee.com/openharmony/commonlibrary_rust_ylong_runtime.git", features = ["net", "sync", "fs", "macros", "time"], optional = true } @@ -32,7 +33,7 @@ sync = [] # Uses sync interfaces. async = [] # Uses async interfaces. http1_1 = ["ylong_http/http1_1"] # Uses HTTP/1.1. http2 = ["ylong_http/http2", "ylong_http/huffman"] # Uses HTTP/2. -http3 = [] # Uses HTTP/3. +http3 = ["ylong_http/http3", "quiche"] # Uses HTTP/3. tokio_base = ["tokio", "ylong_http/tokio_base"] # Uses tokio runtime. ylong_base = ["ylong_runtime", "ylong_http/ylong_base"] # Uses ylong runtime. @@ -42,6 +43,7 @@ __tls = [] # Not open to user, only mark to use tls __c_openssl = ["__tls", "libc"] # Not open to user, only mark to use tls by C-openssl for developer. c_openssl_1_1 = ["__c_openssl"] # Uses TLS by FFI of C-openssl 1.1. c_openssl_3_0 = ["__c_openssl"] # Uses TLS by FFI of C-openssl 3.0. +c_boringssl = ["__tls", "libc"] [[example]] name = "async_certs_adapter" diff --git a/ylong_http_client/build.rs b/ylong_http_client/build.rs index 20b4d74..af0c003 100644 --- a/ylong_http_client/build.rs +++ b/ylong_http_client/build.rs @@ -18,7 +18,7 @@ //! ``OPENSSL_INCLUDE_DIR`` is the path for the Openssl header file. use std::env; - +// todo: check if needed fn main() { let lib_dir = env::var("OPENSSL_LIB_DIR"); let include_dir = env::var("OPENSSL_INCLUDE_DIR"); diff --git a/ylong_http_client/src/async_impl/client.rs b/ylong_http_client/src/async_impl/client.rs index 0606723..19d6e14 100644 --- a/ylong_http_client/src/async_impl/client.rs +++ b/ylong_http_client/src/async_impl/client.rs @@ -22,9 +22,9 @@ use crate::async_impl::interceptor::{IdleInterceptor, Interceptor, Interceptors} use crate::async_impl::request::Message; use crate::error::HttpClientError; use crate::runtime::timeout; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] use crate::util::c_openssl::verify::PubKeyPins; -#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] +#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] use crate::util::config::FchownConfig; use crate::util::config::{ ClientConfig, ConnectorConfig, HttpConfig, HttpVersion, Proxy, Redirect, Timeout, @@ -34,7 +34,7 @@ use crate::util::normalizer::RequestFormatter; use crate::util::proxy::Proxies; use crate::util::redirect::{RedirectInfo, Trigger}; use crate::util::request::RequestArc; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] use crate::CertVerifier; use crate::{ErrorKind, Retry}; @@ -160,8 +160,11 @@ impl Client { impl Client { async fn send_request(&self, request: RequestArc) -> Result { - let response = self.send_unformatted_request(request.clone()).await?; - self.redirect(response, request).await + let mut response = self.send_unformatted_request(request.clone()).await?; + response = self.redirect(response, request.clone()).await?; + #[cfg(feature = "http3")] + self.inner.set_alt_svcs(request, &response); + Ok(response) } async fn send_unformatted_request( @@ -262,7 +265,7 @@ pub struct ClientBuilder { /// Options and flags that is related to `Proxy`. proxies: Proxies, - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] /// Fchown configuration. fchown: Option, @@ -288,7 +291,7 @@ impl ClientBuilder { http: HttpConfig::default(), client: ClientConfig::default(), proxies: Proxies::default(), - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] fchown: None, interceptors: Arc::new(IdleInterceptor), #[cfg(feature = "__tls")] @@ -374,7 +377,7 @@ impl ClientBuilder { /// /// let builder = ClientBuilder::new().sockets_owner(1000, 1000); /// ``` - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] pub fn sockets_owner(mut self, uid: u32, gid: u32) -> Self { self.fchown = Some(FchownConfig::new(uid, gid)); self @@ -448,28 +451,32 @@ impl ClientBuilder { /// let client = ClientBuilder::new().build(); /// ``` pub fn build(self) -> Result, HttpClientError> { - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] use crate::util::{AlpnProtocol, AlpnProtocolList}; - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] let origin_builder = self.tls; - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] let tls_builder = match self.http.version { HttpVersion::Http1 => origin_builder, #[cfg(feature = "http2")] HttpVersion::Http2 => origin_builder.alpn_protos(AlpnProtocol::H2.wire_format_bytes()), HttpVersion::Negotiate => { let supported = AlpnProtocolList::new(); + #[cfg(feature = "http3")] + let supported = supported.extend(AlpnProtocol::H3); #[cfg(feature = "http2")] let supported = supported.extend(AlpnProtocol::H2); let supported = supported.extend(AlpnProtocol::HTTP11); origin_builder.alpn_proto_list(supported) } + #[cfg(feature = "http3")] + HttpVersion::Http3 => origin_builder.alpn_protos(AlpnProtocol::H3.wire_format_bytes()), }; let config = ConnectorConfig { proxies: self.proxies, - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] fchown: self.fchown, #[cfg(feature = "__tls")] tls: tls_builder.build()?, @@ -603,6 +610,65 @@ impl ClientBuilder { } } +#[cfg(feature = "http3")] +impl ClientBuilder { + /// Only use HTTP/3. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().http3_prior_knowledge(); + /// ``` + pub fn http3_prior_knowledge(mut self) -> Self { + self.http.version = HttpVersion::Http3; + self + } + + /// Sets the `SETTINGS_MAX_FIELD_SECTION_SIZE` defined in RFC9114 + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().set_http3_max_field_section_size(16 * 1024); + /// ``` + pub fn set_http3_max_field_section_size(mut self, size: u64) -> Self { + self.http.http3_config.set_max_field_section_size(size); + self + } + + /// Sets the `SETTINGS_QPACK_MAX_TABLE_CAPACITY` defined in RFC9204 + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().set_http3_qpack_max_table_capacity(16 * 1024); + /// ``` + pub fn set_http3_qpack_max_table_capacity(mut self, size: u64) -> Self { + self.http.http3_config.set_qpack_max_table_capacity(size); + self + } + + /// Sets the `SETTINGS_QPACK_BLOCKED_STREAMS` defined in RFC9204 + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().set_http3_qpack_blocked_streams(10); + /// ``` + pub fn set_http3_qpack_blocked_streams(mut self, size: u64) -> Self { + self.http.http3_config.set_qpack_blocked_streams(size); + self + } +} + #[cfg(feature = "__tls")] impl ClientBuilder { /// Sets the maximum allowed TLS version for connections. @@ -661,7 +727,6 @@ impl ClientBuilder { CertificateList::CertList(c) => { self.tls = self.tls.add_root_certificates(c); } - #[cfg(feature = "c_openssl_3_0")] CertificateList::PathList(p) => { self.tls = self.tls.add_path_certificates(p); } @@ -725,27 +790,6 @@ impl ClientBuilder { self } - /// Sets the list of supported ciphers for the `TLSv1.3` protocol. - /// - /// The format consists of TLSv1.3 cipher suite names separated by `:` - /// characters in order of preference. - /// - /// Requires `OpenSSL 1.1.1` or `LibreSSL 3.4.0` or newer. - /// - /// # Examples - /// - /// ``` - /// use ylong_http_client::async_impl::ClientBuilder; - /// - /// let builder = ClientBuilder::new().tls_cipher_suites( - /// "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK", - /// ); - /// ``` - pub fn tls_cipher_suites(mut self, list: &str) -> Self { - self.tls = self.tls.cipher_suites(list); - self - } - /// Controls the use of built-in system certificates during certificate /// validation. Default to `true` -- uses built-in system certs. /// @@ -1063,9 +1107,6 @@ HJMRZVCQpSMzvHlofHSNgzWV1MX5h1CP4SGZdBDTfA== .tls_cipher_list( "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK", ) - .tls_cipher_suites( - "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK", - ) .tls_built_in_root_certs(false) .danger_accept_invalid_certs(false) .danger_accept_invalid_hostnames(false) diff --git a/ylong_http_client/src/async_impl/conn/http2.rs b/ylong_http_client/src/async_impl/conn/http2.rs index 4dd2457..1933466 100644 --- a/ylong_http_client/src/async_impl/conn/http2.rs +++ b/ylong_http_client/src/async_impl/conn/http2.rs @@ -31,8 +31,9 @@ use crate::async_impl::request::Message; use crate::async_impl::{HttpBody, Response}; use crate::error::{ErrorKind, HttpClientError}; use crate::runtime::{AsyncRead, ReadBuf}; +use crate::util::data_ref::BodyDataRef; use crate::util::dispatcher::http2::Http2Conn; -use crate::util::h2::{BodyDataRef, RequestWrapper}; +use crate::util::h2::RequestWrapper; use crate::util::normalizer::BodyLengthParser; const UNUSED_FLAG: u8 = 0x0; diff --git a/ylong_http_client/src/async_impl/conn/http3.rs b/ylong_http_client/src/async_impl/conn/http3.rs new file mode 100644 index 0000000..3956134 --- /dev/null +++ b/ylong_http_client/src/async_impl/conn/http3.rs @@ -0,0 +1,320 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::min; +use std::ops::Deref; +use std::pin::Pin; +use std::sync::atomic::Ordering; +use std::task::{Context, Poll}; + +use ylong_http::error::HttpError; +use ylong_http::h3::{ + Frame, H3Error, H3ErrorCode, Headers, Parts, Payload, PseudoHeaders, HEADERS_FRAME_TYPE, +}; +use ylong_http::request::uri::Scheme; +use ylong_http::request::RequestPart; +use ylong_http::response::status::StatusCode; +use ylong_http::response::ResponsePart; +use ylong_runtime::io::ReadBuf; + +use crate::async_impl::conn::StreamData; +use crate::async_impl::request::Message; +use crate::async_impl::{HttpBody, Response}; +use crate::runtime::AsyncRead; +use crate::util::data_ref::BodyDataRef; +use crate::util::dispatcher::http3::{DispatchErrorKind, Http3Conn, RequestWrapper, RespMessage}; +use crate::util::normalizer::BodyLengthParser; +use crate::{ErrorKind, HttpClientError}; + +pub(crate) async fn request( + mut conn: Http3Conn, + mut message: Message, +) -> Result +where + S: Sync + Send + Unpin + 'static, +{ + message + .interceptor + .intercept_request(message.request.ref_mut())?; + let part = message.request.ref_mut().part().clone(); + + // TODO Implement trailer. + let headers = build_headers_frame(part) + .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?; + let data = BodyDataRef::new(message.request.clone()); + let stream = RequestWrapper { + header: headers, + data, + }; + conn.send_frame_to_reader(stream)?; + let frame = conn.recv_resp().await?; + frame_2_response(conn, frame, message) +} + +pub(crate) fn build_headers_frame(mut part: RequestPart) -> Result { + // todo: check rfc to see if any headers should be removed + let pseudo = build_pseudo_headers(&mut part)?; + let mut header_part = Parts::new(); + header_part.set_header_lines(part.headers); + header_part.set_pseudo(pseudo); + let headers_payload = Headers::new(header_part); + + Ok(Frame::new( + HEADERS_FRAME_TYPE, + Payload::Headers(headers_payload), + )) +} + +// todo: error if headers not enough, should meet rfc +fn build_pseudo_headers(request_part: &mut RequestPart) -> Result { + let mut pseudo = PseudoHeaders::default(); + match request_part.uri.scheme() { + Some(scheme) => { + pseudo.set_scheme(Some(String::from(scheme.as_str()))); + } + None => pseudo.set_scheme(Some(String::from(Scheme::HTTPS.as_str()))), + } + pseudo.set_method(Some(String::from(request_part.method.as_str()))); + pseudo.set_path( + request_part + .uri + .path_and_query() + .or_else(|| Some(String::from("/"))), + ); + let host = request_part + .headers + .remove("host") + .and_then(|auth| auth.to_string().ok()); + pseudo.set_authority(host); + Ok(pseudo) +} + +fn frame_2_response( + conn: Http3Conn, + headers_frame: Frame, + mut message: Message, +) -> Result +where + S: Sync + Send + Unpin + 'static, +{ + let part = match headers_frame.payload() { + Payload::Headers(headers) => { + let part = headers.get_part(); + let (pseudo, fields) = part.parts(); + let status_code = match pseudo.status() { + Some(status) => StatusCode::from_bytes(status.as_bytes()) + .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?, + None => { + return Err(HttpClientError::from_str( + ErrorKind::Request, + "status code not found", + )); + } + }; + ResponsePart { + version: ylong_http::version::Version::HTTP3, + status: status_code, + headers: fields.clone(), + } + } + Payload::PushPromise(_) => { + todo!(); + } + _ => { + return Err(HttpClientError::from_str(ErrorKind::Request, "bad frame")); + } + }; + + let data_io = TextIo::new(conn); + let length = match BodyLengthParser::new(message.request.ref_mut().method(), &part).parse() { + Ok(length) => length, + Err(e) => { + return Err(e); + } + }; + let body = HttpBody::new(message.interceptor, length, Box::new(data_io), &[0u8; 0])?; + + Ok(Response::new( + ylong_http::response::Response::from_raw_parts(part, body), + )) +} + +struct TextIo { + pub(crate) handle: Http3Conn, + pub(crate) offset: usize, + pub(crate) remain: Option, + pub(crate) is_closed: bool, +} + +struct HttpReadBuf<'a, 'b> { + buf: &'a mut ReadBuf<'b>, +} + +impl<'a, 'b> HttpReadBuf<'a, 'b> { + pub(crate) fn append_slice(&mut self, buf: &[u8]) { + #[cfg(feature = "ylong_base")] + self.buf.append(buf); + + #[cfg(feature = "tokio_base")] + self.buf.put_slice(buf); + } +} + +impl<'a, 'b> Deref for HttpReadBuf<'a, 'b> { + type Target = ReadBuf<'b>; + + fn deref(&self) -> &Self::Target { + self.buf + } +} + +impl TextIo +where + S: Sync + Send + Unpin + 'static, +{ + pub(crate) fn new(handle: Http3Conn) -> Self { + Self { + handle, + offset: 0, + remain: None, + is_closed: false, + } + } + + fn match_channel_message( + poll_result: Poll, + text_io: &mut TextIo, + buf: &mut HttpReadBuf, + ) -> Option>> { + match poll_result { + Poll::Ready(RespMessage::Output(frame)) => match frame.payload() { + Payload::Headers(_) => { + text_io.remain = Some(frame); + text_io.offset = 0; + Some(Poll::Ready(Ok(()))) + } + Payload::Data(data) => { + let data = data.data(); + let unfilled_len = buf.remaining(); + let data_len = data.len(); + let fill_len = min(data_len, unfilled_len); + if unfilled_len < data_len { + buf.append_slice(&data[..fill_len]); + text_io.offset += fill_len; + text_io.remain = Some(frame); + Some(Poll::Ready(Ok(()))) + } else { + buf.append_slice(&data[..fill_len]); + Self::end_read(text_io, data_len) + } + } + _ => Some(Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)), + )))), + }, + Poll::Ready(RespMessage::OutputExit(e)) => match e { + DispatchErrorKind::H3(H3Error::Connection(H3ErrorCode::H3NoError)) + | DispatchErrorKind::StreamFinished => { + text_io.is_closed = true; + Some(Poll::Ready(Ok(()))) + } + _ => Some(Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)), + )))), + }, + Poll::Pending => Some(Poll::Pending), + } + } + + fn end_read(text_io: &mut TextIo, data_len: usize) -> Option>> { + text_io.offset = 0; + text_io.remain = None; + if data_len == 0 { + // no data read and is not end stream. + None + } else { + Some(Poll::Ready(Ok(()))) + } + } + + fn read_remaining_data( + text_io: &mut TextIo, + buf: &mut HttpReadBuf, + ) -> Option>> { + if let Some(frame) = &text_io.remain { + return match frame.payload() { + Payload::Headers(_) => Some(Poll::Ready(Ok(()))), + Payload::Data(data) => { + let data = data.data(); + let unfilled_len = buf.remaining(); + let data_len = data.len() - text_io.offset; + let fill_len = min(unfilled_len, data_len); + // The peripheral function already ensures that the remaing of buf will not be + // 0. + if unfilled_len < data_len { + buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]); + text_io.offset += fill_len; + Some(Poll::Ready(Ok(()))) + } else { + buf.append_slice(&data[text_io.offset..text_io.offset + fill_len]); + Self::end_read(text_io, data_len) + } + } + _ => Some(Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + HttpError::from(H3Error::Connection(H3ErrorCode::H3InternalError)), + )))), + }; + } + None + } +} + +impl StreamData for TextIo { + fn shutdown(&self) { + self.handle.io_shutdown.store(true, Ordering::Relaxed); + } +} + +impl AsyncRead for TextIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let text_io = self.get_mut(); + let mut buf = HttpReadBuf { buf }; + + if buf.remaining() == 0 || text_io.is_closed { + return Poll::Ready(Ok(())); + } + while buf.remaining() != 0 { + if let Some(result) = Self::read_remaining_data(text_io, &mut buf) { + return result; + } + + let poll_result = text_io + .handle + .resp_receiver + .poll_recv(cx) + .map_err(|_e| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; + + if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) { + return result; + } + } + Poll::Ready(Ok(())) + } +} diff --git a/ylong_http_client/src/async_impl/conn/mod.rs b/ylong_http_client/src/async_impl/conn/mod.rs index 17405b1..ee1ee31 100644 --- a/ylong_http_client/src/async_impl/conn/mod.rs +++ b/ylong_http_client/src/async_impl/conn/mod.rs @@ -17,6 +17,9 @@ mod http1; #[cfg(feature = "http2")] mod http2; +#[cfg(feature = "http3")] +mod http3; + use crate::async_impl::connector::ConnInfo; use crate::async_impl::request::Message; use crate::async_impl::Response; @@ -41,5 +44,8 @@ where #[cfg(feature = "http2")] Conn::Http2(http2) => http2::request(http2, message).await, + + #[cfg(feature = "http3")] + Conn::Http3(http3) => http3::request(http3, message).await, } } diff --git a/ylong_http_client/src/async_impl/connector/mod.rs b/ylong_http_client/src/async_impl/connector/mod.rs index 73ae085..cf88b99 100644 --- a/ylong_http_client/src/async_impl/connector/mod.rs +++ b/ylong_http_client/src/async_impl/connector/mod.rs @@ -20,9 +20,11 @@ use core::future::Future; /// Information of an IO. pub use stream::ConnInfo; use ylong_http::request::uri::Uri; +#[cfg(feature = "http3")] +use ylong_runtime::net::{ConnectedUdpSocket, UdpSocket}; use crate::runtime::{AsyncRead, AsyncWrite, TcpStream}; -use crate::util::config::ConnectorConfig; +use crate::util::config::{ConnectorConfig, HttpVersion}; use crate::HttpClientError; /// `Connector` trait used by `async_impl::Client`. `Connector` provides @@ -39,7 +41,7 @@ pub trait Connector { + 'static; /// Attempts to establish a connection. - fn connect(&self, uri: &Uri) -> Self::Future; + fn connect(&self, uri: &Uri, http_version: HttpVersion) -> Self::Future; } /// Connector for creating HTTP or HTTPS connections asynchronously. @@ -79,6 +81,22 @@ async fn tcp_stream(addr: &str) -> Result { }) } +#[cfg(feature = "http3")] +pub(crate) async fn udp_stream( + addr: &std::net::SocketAddr, +) -> Result { + let local_addr = match addr { + std::net::SocketAddr::V4(_) => "0.0.0.0:0", + std::net::SocketAddr::V6(_) => "[::]:0", + }; + let sock = UdpSocket::bind(local_addr) + .await + .map_err(|e| HttpClientError::from_io_error(crate::ErrorKind::Connect, e))?; + sock.connect(addr) + .await + .map_err(|e| HttpClientError::from_io_error(crate::ErrorKind::Connect, e)) +} + #[cfg(not(feature = "__tls"))] mod no_tls { use core::future::Future; @@ -140,18 +158,22 @@ mod tls { use super::{tcp_stream, Connector, HttpConnector}; use crate::async_impl::connector::stream::HttpStream; use crate::async_impl::interceptor::{ConnDetail, ConnProtocol}; - use crate::async_impl::ssl_stream::{AsyncSslStream, MixStream}; - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + use crate::async_impl::mix::MixStream; + #[cfg(feature = "http3")] + use crate::async_impl::quic::QuicConn; + use crate::async_impl::ssl_stream::AsyncSslStream; + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] use crate::config::FchownConfig; use crate::runtime::{AsyncReadExt, AsyncWriteExt, TcpStream}; + use crate::util::config::HttpVersion; use crate::{HttpClientError, TlsConfig}; impl Connector for HttpConnector { - type Stream = HttpStream>; + type Stream = HttpStream; type Future = Pin> + Sync + Send>>; - fn connect(&self, uri: &Uri) -> Self::Future { + fn connect(&self, uri: &Uri, _http_version: HttpVersion) -> Self::Future { // Make sure all parts of uri is accurate. let mut addr = uri.authority().unwrap().to_string(); let mut auth = None; @@ -167,16 +189,12 @@ mod tls { .and_then(|v| v.to_string().ok()); is_proxy = true; } - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] let fchown = self.config.fchown.clone(); match *uri.scheme().unwrap() { Scheme::HTTP => Box::pin(async move { let stream = tcp_stream(&addr).await?; - #[cfg(all( - target_os = "linux", - feature = "ylong_base", - feature = "__c_openssl" - ))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] if let Some(fchown) = fchown { let _ = stream.fchown(fchown.uid, fchown.gid); } @@ -202,19 +220,63 @@ mod tls { let host = uri.host().unwrap().to_string(); let port = uri.port().unwrap().as_u16().unwrap(); let config = self.config.tls.clone(); + #[cfg(feature = "http3")] + if _http_version == HttpVersion::Http3 { + return Box::pin(async move { + let addrs = std::net::ToSocketAddrs::to_socket_addrs(&addr.clone()) + .map_err(|e| { + HttpClientError::from_io_error(crate::ErrorKind::Connect, e) + })?; + + let mut last_e = None; + for addr_it in addrs { + let udp_socket = match super::udp_stream(&addr_it).await { + Ok(socket) => socket, + Err(e) => { + last_e = Some(e); + continue; + } + }; + let local = udp_socket.local_addr().map_err(|e| { + HttpClientError::from_io_error(crate::ErrorKind::Connect, e) + })?; + let peer = udp_socket.peer_addr().map_err(|e| { + HttpClientError::from_io_error(crate::ErrorKind::Connect, e) + })?; + let detail = ConnDetail { + protocol: ConnProtocol::Udp, + alpn: None, + local, + peer, + addr: addr.clone(), + proxy: false, + }; + let mut stream = + HttpStream::new(MixStream::Udp(udp_socket), detail); + let Ok(quic_conn) = + QuicConn::connect(&mut stream, &config, &host).await + else { + continue; + }; + stream.set_quic_conn(quic_conn); + return Ok(stream); + } + + Err(last_e.unwrap_or(HttpClientError::from_str( + crate::ErrorKind::Connect, + "connect failed", + ))) + }); + } Box::pin(async move { - #[cfg(all( - target_os = "linux", - feature = "ylong_base", - feature = "__c_openssl" - ))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] { https_connect(config, addr, is_proxy, auth, host, port, fchown).await } #[cfg(not(all( target_os = "linux", feature = "ylong_base", - feature = "__c_openssl" + feature = "__tls" )))] { https_connect(config, addr, is_proxy, auth, host, port).await @@ -232,11 +294,12 @@ mod tls { auth: Option, host: String, port: u16, - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] - fchown: Option, - ) -> Result>, HttpClientError> { + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] fchown: Option< + FchownConfig, + >, + ) -> Result, HttpClientError> { let mut tcp = tcp_stream(addr.as_str()).await?; - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] if let Some(fchown) = fchown { let _ = tcp.fchown(fchown.uid, fchown.gid); } diff --git a/ylong_http_client/src/async_impl/connector/stream.rs b/ylong_http_client/src/async_impl/connector/stream.rs index 2398944..cd0419a 100644 --- a/ylong_http_client/src/async_impl/connector/stream.rs +++ b/ylong_http_client/src/async_impl/connector/stream.rs @@ -17,6 +17,8 @@ use std::pin::Pin; use std::task::{Context, Poll}; use crate::async_impl::interceptor::ConnDetail; +#[cfg(feature = "http3")] +use crate::async_impl::quic::QuicConn; use crate::runtime::{AsyncRead, AsyncWrite, ReadBuf}; /// `ConnDetail` trait, which is used to obtain information about the current @@ -25,13 +27,20 @@ pub trait ConnInfo { /// Whether the current connection is a proxy. fn is_proxy(&self) -> bool; + /// Gets connection information. fn conn_detail(&self) -> ConnDetail; + + /// Gets quic information + #[cfg(feature = "http3")] + fn quic_conn(&mut self) -> Option; } /// A connection wrapper containing io and io information. pub struct HttpStream { detail: ConnDetail, stream: T, + #[cfg(feature = "http3")] + quic_conn: Option, } impl AsyncRead for HttpStream @@ -78,11 +87,26 @@ impl ConnInfo for HttpStream { fn conn_detail(&self) -> ConnDetail { self.detail.clone() } + + #[cfg(feature = "http3")] + fn quic_conn(&mut self) -> Option { + self.quic_conn.take() + } } impl HttpStream { /// HttpStream constructor. pub fn new(io: T, detail: ConnDetail) -> HttpStream { - HttpStream { detail, stream: io } + HttpStream { + detail, + stream: io, + #[cfg(feature = "http3")] + quic_conn: None, + } + } + + #[cfg(feature = "http3")] + pub fn set_quic_conn(&mut self, conn: QuicConn) { + self.quic_conn = Some(conn); } } diff --git a/ylong_http_client/src/async_impl/interceptor/mod.rs b/ylong_http_client/src/async_impl/interceptor/mod.rs index da02e73..b0f1a4c 100644 --- a/ylong_http_client/src/async_impl/interceptor/mod.rs +++ b/ylong_http_client/src/async_impl/interceptor/mod.rs @@ -31,7 +31,7 @@ pub enum ConnProtocol { Udp, } -/// Tcp connection information. +/// Connection information. #[derive(Clone)] pub struct ConnDetail { /// Transport layer protocol type. diff --git a/ylong_http_client/src/async_impl/ssl_stream/mix.rs b/ylong_http_client/src/async_impl/mix.rs similarity index 74% rename from ylong_http_client/src/async_impl/ssl_stream/mix.rs rename to ylong_http_client/src/async_impl/mix.rs index 4d87042..84a5289 100644 --- a/ylong_http_client/src/async_impl/ssl_stream/mix.rs +++ b/ylong_http_client/src/async_impl/mix.rs @@ -14,21 +14,25 @@ use core::pin::Pin; use core::task::{Context, Poll}; +#[cfg(feature = "http3")] +use ylong_runtime::net::ConnectedUdpSocket; +use ylong_runtime::net::TcpStream; + use crate::async_impl::ssl_stream::AsyncSslStream; use crate::runtime::{AsyncRead, AsyncWrite, ReadBuf}; /// A stream which may be wrapped with TLS. -pub enum MixStream { +pub enum MixStream { /// A raw HTTP stream. - Http(T), + Http(TcpStream), /// An SSL-wrapped HTTP stream. - Https(AsyncSslStream), + Https(AsyncSslStream), + #[cfg(feature = "http3")] + /// A Udp connection + Udp(ConnectedUdpSocket), } -impl AsyncRead for MixStream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncRead for MixStream { // poll_read separately. fn poll_read( mut self: Pin<&mut Self>, @@ -38,14 +42,13 @@ where match &mut *self { MixStream::Http(s) => Pin::new(s).poll_read(cx, buf), MixStream::Https(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(feature = "http3")] + MixStream::Udp(s) => Pin::new(s).poll_recv(cx, buf), } } } -impl AsyncWrite for MixStream -where - T: AsyncRead + AsyncWrite + Unpin, -{ +impl AsyncWrite for MixStream { // poll_write separately. fn poll_write( mut self: Pin<&mut Self>, @@ -55,6 +58,8 @@ where match &mut *self { MixStream::Http(s) => Pin::new(s).poll_write(ctx, buf), MixStream::Https(s) => Pin::new(s).poll_write(ctx, buf), + #[cfg(feature = "http3")] + MixStream::Udp(s) => Pin::new(s).poll_send(ctx, buf), } } @@ -62,6 +67,8 @@ where match &mut *self { MixStream::Http(s) => Pin::new(s).poll_flush(ctx), MixStream::Https(s) => Pin::new(s).poll_flush(ctx), + #[cfg(feature = "http3")] + MixStream::Udp(_) => Poll::Ready(Ok(())), } } @@ -69,6 +76,8 @@ where match &mut *self { MixStream::Http(s) => Pin::new(s).poll_shutdown(ctx), MixStream::Https(s) => Pin::new(s).poll_shutdown(ctx), + #[cfg(feature = "http3")] + MixStream::Udp(_) => Poll::Ready(Ok(())), } } } diff --git a/ylong_http_client/src/async_impl/mod.rs b/ylong_http_client/src/async_impl/mod.rs index 07ec4ad..9e4c6b1 100644 --- a/ylong_http_client/src/async_impl/mod.rs +++ b/ylong_http_client/src/async_impl/mod.rs @@ -36,14 +36,21 @@ mod uploader; #[cfg(feature = "__tls")] mod ssl_stream; +#[cfg(feature = "__tls")] +pub(crate) mod mix; + pub(crate) mod conn; pub(crate) mod pool; +#[cfg(feature = "http3")] +mod quic; pub use client::ClientBuilder; -pub use connector::{Connector, HttpConnector}; +pub use connector::{ConnInfo, Connector, HttpConnector}; pub use downloader::{DownloadOperator, Downloader, DownloaderBuilder}; pub use http_body::HttpBody; pub use interceptor::{ConnDetail, ConnProtocol, Interceptor}; +#[cfg(feature = "http3")] +pub use quic::QuicConn; pub use request::{Body, PercentEncoder, Request, RequestBuilder}; pub use response::Response; pub use uploader::{UploadOperator, Uploader, UploaderBuilder}; diff --git a/ylong_http_client/src/async_impl/pool.rs b/ylong_http_client/src/async_impl/pool.rs index 58befb5..2ed4317 100644 --- a/ylong_http_client/src/async_impl/pool.rs +++ b/ylong_http_client/src/async_impl/pool.rs @@ -14,22 +14,36 @@ use std::mem::take; use std::sync::{Arc, Mutex}; +#[cfg(feature = "http3")] +use ylong_http::request::uri::Authority; #[cfg(feature = "http2")] use ylong_http::request::uri::Scheme; use ylong_http::request::uri::Uri; use crate::async_impl::connector::ConnInfo; +#[cfg(feature = "http3")] +use crate::async_impl::quic::QuicConn; use crate::async_impl::Connector; +#[cfg(feature = "http3")] +use crate::async_impl::Response; use crate::error::HttpClientError; use crate::runtime::{AsyncRead, AsyncWrite}; +#[cfg(feature = "http3")] +use crate::util::alt_svc::{AltService, AltServiceMap}; #[cfg(feature = "http2")] use crate::util::config::H2Config; +#[cfg(feature = "http3")] +use crate::util::config::H3Config; use crate::util::config::{HttpConfig, HttpVersion}; use crate::util::dispatcher::{Conn, ConnDispatcher, Dispatcher}; use crate::util::pool::{Pool, PoolKey}; +#[cfg(feature = "http3")] +use crate::util::request::RequestArc; pub(crate) struct ConnPool { pool: Pool>, + #[cfg(feature = "http3")] + alt_svcs: AltServiceMap, connector: Arc, config: HttpConfig, } @@ -38,6 +52,8 @@ impl ConnPool { pub(crate) fn new(config: HttpConfig, connector: C) -> Self { Self { pool: Pool::new(), + #[cfg(feature = "http3")] + alt_svcs: AltServiceMap::new(), connector: Arc::new(connector), config, } @@ -49,17 +65,33 @@ impl ConnPool { uri.authority().unwrap().clone(), ); + #[cfg(feature = "http3")] + let alt_svc = self.alt_svcs.get_alt_svcs(&key); + self.pool .get(key, Conns::new) - .conn(self.config.clone(), self.connector.clone(), uri) + .conn( + self.config.clone(), + self.connector.clone(), + uri, + #[cfg(feature = "http3")] + alt_svc, + ) .await } + + #[cfg(feature = "http3")] + pub(crate) fn set_alt_svcs(&self, request: RequestArc, response: &Response) { + self.alt_svcs.set_alt_svcs(request, response); + } } pub(crate) struct Conns { list: Arc>>>, #[cfg(feature = "http2")] h2_conn: Arc>>>, + #[cfg(feature = "http3")] + h3_conn: Arc>>>, } impl Conns { @@ -69,8 +101,13 @@ impl Conns { #[cfg(feature = "http2")] h2_conn: Arc::new(crate::runtime::AsyncMutex::new(Vec::with_capacity(1))), + + #[cfg(feature = "http3")] + h3_conn: Arc::new(crate::runtime::AsyncMutex::new(Vec::with_capacity(1))), } } + + // fn get_alt_svcs } impl Clone for Conns { @@ -80,6 +117,9 @@ impl Clone for Conns { #[cfg(feature = "http2")] h2_conn: self.h2_conn.clone(), + + #[cfg(feature = "http3")] + h3_conn: self.h3_conn.clone(), } } } @@ -90,16 +130,27 @@ impl Conns config: HttpConfig, connector: Arc, url: &Uri, + #[cfg(feature = "http3")] alt_svc: Option>, ) -> Result, HttpClientError> where C: Connector, { match config.version { + #[cfg(feature = "http3")] + HttpVersion::Http3 => self.conn_h3(connector, url, config.http3_config).await, #[cfg(feature = "http2")] HttpVersion::Http2 => self.conn_h2(connector, url, config.http2_config).await, #[cfg(feature = "http1_1")] HttpVersion::Http1 => self.conn_h1(connector, url).await, HttpVersion::Negotiate => { + #[cfg(feature = "http3")] + if let Some(conn) = self + .conn_alt_svc(&connector, url, alt_svc, config.http3_config) + .await + { + return Ok(conn); + } + #[cfg(all(feature = "http1_1", not(feature = "http2")))] return self.conn_h1(connector, url).await; @@ -118,7 +169,7 @@ impl Conns if let Some(conn) = self.exist_h1_conn() { return Ok(conn); } - let dispatcher = ConnDispatcher::http1(connector.connect(url).await?); + let dispatcher = ConnDispatcher::http1(connector.connect(url, HttpVersion::Http1).await?); Ok(self.dispatch_h1_conn(dispatcher)) } @@ -140,7 +191,7 @@ impl Conns if let Some(conn) = Self::exist_h2_conn(&mut lock) { return Ok(conn); } - let stream = connector.connect(url).await?; + let stream = connector.connect(url, HttpVersion::Http2).await?; let details = stream.conn_detail(); let tls = if let Some(scheme) = url.scheme() { *scheme == Scheme::HTTPS @@ -158,12 +209,36 @@ impl Conns Ok(Self::dispatch_h2_conn(config, stream, &mut lock)) } + #[cfg(feature = "http3")] + async fn conn_h3( + &self, + connector: Arc, + url: &Uri, + config: H3Config, + ) -> Result, HttpClientError> + where + C: Connector, + { + let mut lock = self.h3_conn.lock().await; + + if let Some(conn) = Self::exist_h3_conn(&mut lock) { + return Ok(conn); + } + let mut stream = connector.connect(url, HttpVersion::Http3).await?; + let quic_conn = stream.quic_conn().ok_or(HttpClientError::from_str( + crate::ErrorKind::Connect, + "QUIC connect failed", + ))?; + + Ok(Self::dispatch_h3_conn(config, stream, quic_conn, &mut lock)) + } + #[cfg(all(feature = "http2", feature = "http1_1"))] async fn conn_negotiate( &self, connector: Arc, url: &Uri, - config: H2Config, + h2_config: H2Config, ) -> Result, HttpClientError> where C: Connector, @@ -171,7 +246,6 @@ impl Conns match *url.scheme().unwrap() { Scheme::HTTPS => { let mut lock = self.h2_conn.lock().await; - if let Some(conn) = Self::exist_h2_conn(&mut lock) { return Ok(conn); } @@ -180,7 +254,7 @@ impl Conns return Ok(conn); } - let stream = connector.connect(url).await?; + let stream = connector.connect(url, HttpVersion::Negotiate).await?; let details = stream.conn_detail(); let protocol = if let Some(bytes) = details.alpn() { @@ -194,7 +268,7 @@ impl Conns let dispatcher = ConnDispatcher::http1(stream); Ok(self.dispatch_h1_conn(dispatcher)) } else if protocol == b"h2" { - Ok(Self::dispatch_h2_conn(config, stream, &mut lock)) + Ok(Self::dispatch_h2_conn(h2_config, stream, &mut lock)) } else { err_from_msg!(Connect, "Alpn negotiate a wrong protocol version.") } @@ -203,6 +277,52 @@ impl Conns } } + #[cfg(feature = "http3")] + async fn conn_alt_svc( + &self, + connector: &Arc, + url: &Uri, + alt_svcs: Option>, + h3_config: H3Config, + ) -> Option> + where + C: Connector, + { + let mut lock = self.h3_conn.lock().await; + if let Some(conn) = Self::exist_h3_conn(&mut lock) { + return Some(conn); + } + if let Some(alt_svcs) = alt_svcs { + for alt_svc in alt_svcs { + // only support h3 alt_svc now + if alt_svc.http_version != HttpVersion::Http3 { + continue; + } + let scheme = Scheme::HTTPS; + let host = match alt_svc.host { + Some(ref host) => host.clone(), + None => url.host().cloned().unwrap(), + }; + let port = alt_svc.port.clone(); + let authority = + Authority::from_bytes((host.to_string() + ":" + port.as_str()).as_bytes()) + .ok()?; + let path = url.path().cloned(); + let query = url.query().cloned(); + let alt_url = Uri::from_raw_parts(Some(scheme), Some(authority), path, query); + let mut stream = connector.connect(&alt_url, HttpVersion::Http3).await.ok()?; + let quic_conn = stream.quic_conn().unwrap(); + return Some(Self::dispatch_h3_conn( + h3_config.clone(), + stream, + quic_conn, + &mut lock, + )); + } + } + None + } + fn dispatch_h1_conn(&self, dispatcher: ConnDispatcher) -> Conn { // We must be able to get the `Conn` here. let conn = dispatcher.dispatch().unwrap(); @@ -224,6 +344,19 @@ impl Conns conn } + #[cfg(feature = "http3")] + fn dispatch_h3_conn( + config: H3Config, + stream: S, + quic_connection: QuicConn, + lock: &mut crate::runtime::MutexGuard>>, + ) -> Conn { + let dispatcher = ConnDispatcher::http3(config, stream, quic_connection); + let conn = dispatcher.dispatch().unwrap(); + lock.push(dispatcher); + conn + } + fn exist_h1_conn(&self) -> Option> { let mut list = self.list.lock().unwrap(); let mut conn = None; @@ -247,6 +380,7 @@ impl Conns lock: &mut crate::runtime::MutexGuard>>, ) -> Option> { if let Some(dispatcher) = lock.pop() { + // todo: shutdown and goaway if !dispatcher.is_shutdown() { if let Some(conn) = dispatcher.dispatch() { lock.push(dispatcher); @@ -256,4 +390,24 @@ impl Conns } None } + + #[cfg(feature = "http3")] + fn exist_h3_conn( + lock: &mut crate::runtime::MutexGuard>>, + ) -> Option> { + if let Some(dispatcher) = lock.pop() { + if dispatcher.is_shutdown() { + return None; + } + if !dispatcher.is_goaway() { + if let Some(conn) = dispatcher.dispatch() { + lock.push(dispatcher); + return Some(conn); + } + } + // Not all requests have been processed yet + lock.push(dispatcher); + } + None + } } diff --git a/ylong_http_client/src/async_impl/quic/mod.rs b/ylong_http_client/src/async_impl/quic/mod.rs new file mode 100644 index 0000000..223a14c --- /dev/null +++ b/ylong_http_client/src/async_impl/quic/mod.rs @@ -0,0 +1,294 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! `ConnDetail` trait and `HttpStream` implementation. + +use std::ffi::c_void; +use std::net::SocketAddr; +use std::ops::{Deref, DerefMut}; +use std::ptr; + +use libc::{ + in6_addr, in_addr, sa_family_t, size_t, sockaddr, sockaddr_in, sockaddr_in6, sockaddr_storage, + socklen_t, AF_INET, AF_INET6, +}; +use ylong_runtime::fastrand::fast_random; +use ylong_runtime::time::timeout; + +use crate::async_impl::connector::ConnInfo; +use crate::c_openssl::ssl::verify_server_cert; +use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use crate::util::c_openssl::ssl::Ssl; +use crate::{ErrorKind, HttpClientError, TlsConfig}; + +const MAX_DATAGRAM_SIZE: usize = 1350; +const UDP_BUF_SIZE: usize = 65535; +const MAX_STREAM_DATA: u64 = 1_000_000; +const MAX_TOTAL_DATA: u64 = 10_000_000; +const MAX_STREAM_NUM: u64 = 100; +const MAX_IDLE_TIME: u64 = 5000; + +pub struct QuicConn { + inner: quiche::Connection, +} + +impl QuicConn { + fn quic_config() -> Result { + let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?; + config.verify_peer(true); + config.set_application_protos(quiche::h3::APPLICATION_PROTOCOL)?; + config.set_max_idle_timeout(MAX_IDLE_TIME); + config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + config.set_max_send_udp_payload_size(MAX_DATAGRAM_SIZE); + config.set_initial_max_data(MAX_TOTAL_DATA); + config.set_initial_max_stream_data_bidi_local(MAX_STREAM_DATA); + config.set_initial_max_stream_data_bidi_remote(MAX_STREAM_DATA); + config.set_initial_max_stream_data_uni(MAX_STREAM_DATA); + config.set_initial_max_streams_bidi(MAX_STREAM_NUM); + config.set_initial_max_streams_uni(MAX_STREAM_NUM); + config.set_disable_active_migration(true); + Ok(config) + } + + pub(crate) async fn connect( + stream: &mut S, + tls_config: &TlsConfig, + host: &str, + ) -> Result + where + S: AsyncRead + AsyncWrite + ConnInfo + Unpin + Sync + Send + 'static, + { + let config = Self::quic_config() + .map_err(|_| HttpClientError::from_str(ErrorKind::Connect, "Quic init error"))?; + // Generate a random source connection ID for the connection. + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + for byte in scid.iter_mut() { + *byte = fast_random() as u8; + } + let scid = quiche::ConnectionId::from_ref(&scid); + + let local = stream.conn_detail().local(); + let peer = stream.conn_detail().peer(); + let mut c_local: sockaddr_storage = unsafe { std::mem::zeroed() }; + let c_local_size = Self::std_addr_to_c(&local, &mut c_local); + let mut c_peer: sockaddr_storage = unsafe { std::mem::zeroed() }; + let c_peer_size = Self::std_addr_to_c(&peer, &mut c_peer); + let mut new_ssl = tls_config.ssl_new(host).unwrap().into_inner(); + + let conn = unsafe { + quiche_conn_new_with_tls( + scid.as_ptr(), + scid.len() as size_t, + ptr::null_mut(), + 0, + &c_local as *const _ as *const sockaddr, + c_local_size, + &c_peer as *const _ as *const sockaddr, + c_peer_size, + &config as *const _ as *const c_void, + new_ssl.get_raw_ptr() as *mut c_void, + false, + ) as *mut quiche::Connection + }; + let mut conn = QuicConn { + inner: unsafe { *Box::from_raw(conn) }, + }; + if let Err(e) = conn.connect_inner(stream, &mut new_ssl, tls_config).await { + std::mem::forget(new_ssl); + return Err(e); + } + std::mem::forget(new_ssl); + if conn.is_established() { + Ok(conn) + } else { + Err(HttpClientError::from_str( + ErrorKind::Connect, + "Quic connect error", + )) + } + } + + async fn connect_inner( + &mut self, + stream: &mut S, + ssl: &mut Ssl, + tls_config: &TlsConfig, + ) -> Result<(), HttpClientError> + where + S: AsyncRead + AsyncWrite + ConnInfo + Unpin + Sync + Send + 'static, + { + let mut buf = [0; UDP_BUF_SIZE]; + let mut out = [0; MAX_DATAGRAM_SIZE]; + let (write, _send_info) = self.send(&mut out).expect("initial send failed"); + let mut e: Result<(), HttpClientError> = Ok(()); + stream + .write_all(&out[..write]) + .await + .map_err(|e| HttpClientError::from_io_error(crate::ErrorKind::Connect, e))?; + loop { + self.conn_recv(stream, &mut buf).await?; + + if self.is_closed() { + break; + } + if self.is_established() { + let Some(key) = tls_config.pinning_host_match(stream.conn_detail().addr()) else { + break; + }; + + if verify_server_cert(ssl.get_raw_ptr(), key.as_str()).is_ok() { + return Ok(()); + } + e = Err(HttpClientError::from_str( + ErrorKind::Connect, + "verify server cert failed", + )); + if let Err(quiche::Error::Done) = + self.close(false, 0x1, b"verify server cert failed") + { + return e; + } + } + + loop { + let (write, _send_info) = match self.send(&mut out) { + Ok(v) => v, + Err(quiche::Error::Done) => { + break; + } + Err(err) => { + if e.is_ok() { + e = Err(HttpClientError::from_error(ErrorKind::Connect, err)); + } + self.close(false, 0x1, b"fail").ok(); + break; + } + }; + stream + .write_all(&out[..write]) + .await + .map_err(|e| HttpClientError::from_io_error(crate::ErrorKind::Connect, e))?; + } + } + e + } + + async fn conn_recv(&mut self, stream: &mut S, buf: &mut [u8]) -> Result<(), HttpClientError> + where + S: AsyncRead + AsyncWrite + ConnInfo + Unpin + Sync + Send + 'static, + { + let recv_info = quiche::RecvInfo { + to: stream.conn_detail().local(), + from: stream.conn_detail().peer(), + }; + let mut recv_size = 0; + let mut len = 0; + loop { + if len != 0 && recv_size != len { + match self.recv(&mut buf[recv_size..len], recv_info) { + Ok(size) => { + recv_size += size; + if recv_size == len { + return Ok(()); + } else { + continue; + } + } + Err(quiche::Error::Done) => { + return Ok(()); + } + Err(e) => { + return Err(HttpClientError::from_error(ErrorKind::Connect, e)); + } + } + } + len = match self.timeout() { + Some(dur) => { + if let Ok(res) = timeout(dur, stream.read(buf)).await { + res + } else { + self.on_timeout(); + return Ok(()); + } + } + None => stream.read(buf).await, + } + .map_err(|e| HttpClientError::from_io_error(crate::ErrorKind::Connect, e))?; + } + } + + fn std_addr_to_c(addr: &SocketAddr, c_addr: &mut sockaddr_storage) -> socklen_t { + let sin_port = addr.port().to_be(); + + match addr { + SocketAddr::V4(addr) => unsafe { + let sa_len = std::mem::size_of::(); + let c_addr_in = c_addr as *mut _ as *mut sockaddr_in; + let s_addr = u32::from_ne_bytes(addr.ip().octets()); + let sin_addr = in_addr { s_addr }; + *c_addr_in = sockaddr_in { + sin_family: AF_INET as sa_family_t, + sin_addr, + sin_port, + sin_zero: std::mem::zeroed(), + }; + sa_len as socklen_t + }, + SocketAddr::V6(addr) => unsafe { + let sa_len = std::mem::size_of::(); + let c_addr_in6 = c_addr as *mut _ as *mut sockaddr_in6; + let sin6_addr = in6_addr { + s6_addr: addr.ip().octets(), + }; + *c_addr_in6 = sockaddr_in6 { + sin6_family: AF_INET6 as sa_family_t, + sin6_addr, + sin6_port: sin_port, + sin6_flowinfo: addr.flowinfo(), + sin6_scope_id: addr.scope_id(), + }; + sa_len as socklen_t + }, + } + } +} + +impl Deref for QuicConn { + type Target = quiche::Connection; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl DerefMut for QuicConn { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +extern "C" { + pub(crate) fn quiche_conn_new_with_tls( + scid: *const u8, + scid_len: size_t, + odcid: *const u8, + odcid_len: size_t, + local: *const sockaddr, + local_len: socklen_t, + peer: *const sockaddr, + peer_len: socklen_t, + config: *const c_void, + ssl: *mut c_void, + is_server: bool, + ) -> *mut c_void; +} diff --git a/ylong_http_client/src/async_impl/ssl_stream/mod.rs b/ylong_http_client/src/async_impl/ssl_stream/mod.rs index 5088776..3728a35 100644 --- a/ylong_http_client/src/async_impl/ssl_stream/mod.rs +++ b/ylong_http_client/src/async_impl/ssl_stream/mod.rs @@ -11,12 +11,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] mod c_ssl_stream; -mod mix; mod wrapper; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub use c_ssl_stream::AsyncSslStream; -pub use mix::MixStream; pub(crate) use wrapper::{check_io_to_poll, Wrapper}; diff --git a/ylong_http_client/src/error.rs b/ylong_http_client/src/error.rs index 47152ae..0517c07 100644 --- a/ylong_http_client/src/error.rs +++ b/ylong_http_client/src/error.rs @@ -127,7 +127,7 @@ impl HttpClientError { /// /// assert!(!HttpClientError::user_aborted().is_tls_error()) /// ``` - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] pub fn is_tls_error(&self) -> bool { matches!(self.cause, Cause::Tls(_)) } @@ -144,7 +144,7 @@ impl HttpClientError { } } - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] pub(crate) fn from_tls_error(kind: ErrorKind, err: T) -> Self where T: Into>, @@ -273,7 +273,7 @@ impl ErrorKind { pub(crate) enum Cause { NoReason, Dns(io::Error), - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] Tls(Box), Io(io::Error), Msg(&'static str), @@ -285,7 +285,7 @@ impl Debug for Cause { match self { Self::NoReason => write!(f, "No reason"), Self::Dns(err) => Debug::fmt(err, f), - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] Self::Tls(err) => Debug::fmt(err, f), Self::Io(err) => Debug::fmt(err, f), Self::Msg(msg) => write!(f, "{}", msg), @@ -299,7 +299,7 @@ impl Display for Cause { match self { Self::NoReason => write!(f, "No reason"), Self::Dns(err) => Display::fmt(err, f), - #[cfg(feature = "__c_openssl")] + #[cfg(feature = "__tls")] Self::Tls(err) => Display::fmt(err, f), Self::Io(err) => Display::fmt(err, f), Self::Msg(msg) => write!(f, "{}", msg), diff --git a/ylong_http_client/src/sync_impl/client.rs b/ylong_http_client/src/sync_impl/client.rs index f9f65ea..87b23ab 100644 --- a/ylong_http_client/src/sync_impl/client.rs +++ b/ylong_http_client/src/sync_impl/client.rs @@ -424,27 +424,6 @@ impl ClientBuilder { self } - /// Sets the list of supported ciphers for the `TLSv1.3` protocol. - /// - /// The format consists of TLSv1.3 cipher suite names separated by `:` - /// characters in order of preference. - /// - /// Requires `OpenSSL 1.1.1` or `LibreSSL 3.4.0` or newer. - /// - /// # Examples - /// - /// ``` - /// use ylong_http_client::sync_impl::ClientBuilder; - /// - /// let builder = ClientBuilder::new().tls_cipher_suites( - /// "DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK", - /// ); - /// ``` - pub fn tls_cipher_suites(mut self, list: &str) -> Self { - self.tls = self.tls.cipher_suites(list); - self - } - /// Controls the use of built-in system certificates during certificate /// validation. Default to `true` -- uses built-in system certs. /// diff --git a/ylong_http_client/src/sync_impl/ssl_stream.rs b/ylong_http_client/src/sync_impl/ssl_stream.rs index 2a3e3c3..a809c21 100644 --- a/ylong_http_client/src/sync_impl/ssl_stream.rs +++ b/ylong_http_client/src/sync_impl/ssl_stream.rs @@ -13,7 +13,7 @@ use std::io::{Read, Write}; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] use crate::util::c_openssl::ssl::SslStream; /// A stream which may be wrapped with TLS. diff --git a/ylong_http_client/src/util/alt_svc.rs b/ylong_http_client/src/util/alt_svc.rs new file mode 100644 index 0000000..30243a0 --- /dev/null +++ b/ylong_http_client/src/util/alt_svc.rs @@ -0,0 +1,138 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::str; +use std::str::FromStr; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use ylong_http::request::uri::{Host, Port}; + +use crate::async_impl::Response; +use crate::util::config::HttpVersion; +use crate::util::pool::PoolKey; +use crate::util::request::RequestArc; + +const DEFAULT_MAX_AGE: u64 = 24 * 60 * 60; + +#[derive(Clone)] +pub(crate) struct AltService { + pub(crate) http_version: HttpVersion, + // todo: use this later + #[allow(unused)] + pub(crate) src_host: Host, + pub(crate) host: Option, + pub(crate) port: Port, + pub(crate) lifetime: Instant, +} + +pub(crate) struct AltServiceMap { + inner: Arc>>>, +} + +impl AltServiceMap { + pub(crate) fn get_alt_svcs(&self, key: &PoolKey) -> Option> { + let mut lock = self.inner.lock().unwrap(); + let vec = lock.get_mut(key)?; + vec.retain(|alt_scv| alt_scv.lifetime >= Instant::now()); + Some(vec.clone()) + } + + fn parse_alt_svc(src_host: &Host, values: &[u8]) -> Option { + // The alt_value/parameters are divided by ';' + let mut value_it = values.split(|c| *c == b';'); + // the first value_it is alpn="[host]:port" + let alternative = value_it.next()?; + let mut words = alternative.split(|c| *c == b'='); + + let http_version = words.next()?.try_into().ok()?; + let mut host_port = words.next()?; + host_port = &host_port[1..host_port.len() - 1]; + let index = host_port.iter().position(|&x| x == b':')?; + let (host, port) = if index == 0 { + ( + None, + Port::from_str(str::from_utf8(&host_port[1..]).ok()?).ok()?, + ) + } else { + ( + Some(Host::from_str(str::from_utf8(&host_port[..index]).ok()?).ok()?), + Port::from_str(str::from_utf8(&host_port[(index + 1)..]).ok()?).ok()?, + ) + }; + + let mut seconds = DEFAULT_MAX_AGE; + + for para in value_it { + let para = str::from_utf8(para).ok()?.trim().as_bytes(); + // parameter: token "=" ( token / quoted-string ) + let mut para_it = para.split(|c| *c == b'='); + // only support ma now + if para_it.next()? == b"ma" { + let para = str::from_utf8(para_it.next()?).ok()?; + seconds = para.parse::().ok()?; + break; + } + } + + Some(AltService { + http_version, + src_host: src_host.clone(), + host, + port, + lifetime: Instant::now().checked_add(Duration::from_secs(seconds))?, + }) + } + + pub(crate) fn set_alt_svcs(&self, mut request: RequestArc, response: &Response) { + let mut lock = self.inner.lock().unwrap(); + let uri = request.ref_mut().uri(); + let Some(scheme) = uri.scheme() else { + return; + }; + let Some(authority) = uri.authority() else { + return; + }; + let Some(host) = uri.host() else { + return; + }; + let key = PoolKey::new(scheme.clone(), authority.clone()); + match response.headers().get("Alt-Svc") { + None => {} + Some(values) => { + let mut new_alt_svcs = Vec::new(); + for value in values.iter() { + let slice = value.as_slice(); + if slice == "clear".as_bytes() { + lock.remove(&key); + return; + } + // Alt_Svcs are divided by ',' + for slice in slice.split(|c| *c == b',') { + if let Some(alt_svc) = Self::parse_alt_svc(host, slice) { + new_alt_svcs.push(alt_svc); + } + } + } + lock.insert(key, new_alt_svcs); + } + } + } + + pub(crate) fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(HashMap::new())), + } + } +} diff --git a/ylong_http_client/src/util/c_openssl/adapter.rs b/ylong_http_client/src/util/c_openssl/adapter.rs index b28f116..d2d5820 100644 --- a/ylong_http_client/src/util/c_openssl/adapter.rs +++ b/ylong_http_client/src/util/c_openssl/adapter.rs @@ -46,7 +46,6 @@ pub struct TlsConfigBuilder { verify_hostname: bool, certs_list: Vec, pins: Option, - #[cfg(feature = "c_openssl_3_0")] paths_list: Vec, } @@ -68,7 +67,6 @@ impl TlsConfigBuilder { verify_hostname: true, certs_list: vec![], pins: None, - #[cfg(feature = "c_openssl_3_0")] paths_list: vec![], } } @@ -136,8 +134,6 @@ impl TlsConfigBuilder { /// Sets the list of supported ciphers for protocols before `TLSv1.3`. /// - /// The `set_ciphersuites` method controls the cipher suites for `TLSv1.3`. - /// /// See [`ciphers`] for details on the format. /// /// [`ciphers`]: https://www.openssl.org/docs/man1.1.0/apps/ciphers.html @@ -157,31 +153,6 @@ impl TlsConfigBuilder { self } - /// Sets the list of supported ciphers for the `TLSv1.3` protocol. - /// - /// The `set_cipher_list` method controls the cipher suites for protocols - /// before `TLSv1.3`. - /// - /// The format consists of TLSv1.3 cipher suite names separated by `:` - /// characters in order of preference. - /// - /// Requires `OpenSSL 1.1.1` or `LibreSSL 3.4.0` or newer. - /// - /// # Examples - /// - /// ``` - /// use ylong_http_client::TlsConfigBuilder; - /// - /// let builder = TlsConfigBuilder::new() - /// .cipher_suites("DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK"); - /// ``` - pub fn cipher_suites(mut self, list: &str) -> Self { - self.inner = self - .inner - .and_then(|mut builder| builder.set_cipher_suites(list).map(|_| builder)); - self - } - /// Loads a leaf certificate from a file. /// /// Only a single certificate will be loaded - use `add_extra_chain_cert` to @@ -250,7 +221,6 @@ impl TlsConfigBuilder { /// let builder = TlsConfigBuilder::new().add_path_certificates(path); /// # } /// ``` - #[cfg(feature = "c_openssl_3_0")] pub fn add_path_certificates(mut self, path: String) -> Self { self.paths_list.push(path); self @@ -260,7 +230,7 @@ impl TlsConfigBuilder { // Negotiation (ALPN). // // Requires OpenSSL 1.0.2 or LibreSSL 2.6.1 or newer. - #[cfg(feature = "http2")] + #[cfg(any(feature = "http2", feature = "http3"))] pub(crate) fn alpn_protos(mut self, protocols: &[u8]) -> Self { self.inner = self .inner @@ -401,7 +371,6 @@ impl TlsConfigBuilder { }); } - #[cfg(feature = "c_openssl_3_0")] for path in self.paths_list { self.inner = self.inner.and_then(|mut builder| { Ok(builder.cert_store_mut()) @@ -647,7 +616,6 @@ pub struct Certificate { #[derive(Clone)] pub(crate) enum CertificateList { CertList(Vec), - #[cfg(feature = "c_openssl_3_0")] PathList(String), } @@ -665,7 +633,6 @@ impl Certificate { } /// Deserializes a list of PEM-formatted certificates. - #[cfg(feature = "c_openssl_3_0")] pub fn from_path(path: &str) -> Result { Ok(Certificate { inner: CertificateList::PathList(path.to_string()), @@ -695,19 +662,6 @@ mod ut_openssl_adapter { assert!(builder.ca_file("folder/ca.crt").build().is_err()); } - /// UT test cases for `TlsConfigBuilder::new`. - /// - /// # Brief - /// 1. Creates a `TlsConfigBuilder` by calling `TlsConfigBuilder::new`. - /// 2. Calls `set_cipher_suites`. - /// 3. Provides an invalid path as argument. - /// 4. Checks if the result is as expected. - #[test] - fn ut_set_cipher_suites() { - let builder = TlsConfigBuilder::new().cipher_suites("INVALID STRING"); - assert!(builder.build().is_err()); - } - /// UT test cases for `TlsConfigBuilder::set_max_proto_version`. /// /// # Brief @@ -791,9 +745,6 @@ mod ut_openssl_adapter { fn ut_add_root_certificates() { let certificate = Certificate::from_pem(include_bytes!("../../../tests/file/root-ca.pem")) .expect("Sets certs error."); - #[cfg(feature = "c_openssl_1_1")] - let CertificateList::CertList(certs) = certificate.inner; - #[cfg(feature = "c_openssl_3_0")] let certs = match certificate.inner { CertificateList::CertList(c) => c, CertificateList::PathList(_) => vec![], diff --git a/ylong_http_client/src/util/c_openssl/error.rs b/ylong_http_client/src/util/c_openssl/error.rs index 3cafcd8..9622958 100644 --- a/ylong_http_client/src/util/c_openssl/error.rs +++ b/ylong_http_client/src/util/c_openssl/error.rs @@ -19,14 +19,14 @@ use std::error::Error; use std::ffi::CString; use std::fmt; -#[cfg(feature = "c_openssl_1_1")] +#[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] use libc::c_char; use libc::{c_int, c_ulong}; use super::ssl_init; #[cfg(feature = "c_openssl_3_0")] use crate::util::c_openssl::ffi::err::ERR_get_error_all; -#[cfg(feature = "c_openssl_1_1")] +#[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] use crate::util::c_openssl::ffi::err::{ERR_func_error_string, ERR_get_error_line_data}; use crate::util::c_openssl::ffi::err::{ERR_lib_error_string, ERR_reason_error_string}; @@ -37,7 +37,7 @@ const ERR_TXT_STRING: c_int = 0x02; #[derive(Debug)] pub(crate) struct StackError { code: c_ulong, - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] file: *const c_char, #[cfg(feature = "c_openssl_3_0")] file: CString, @@ -51,7 +51,7 @@ impl Clone for StackError { fn clone(&self) -> Self { Self { code: self.code, - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] file: self.file, #[cfg(feature = "c_openssl_3_0")] file: self.file.clone(), @@ -76,7 +76,7 @@ impl StackError { let mut data = ptr::null(); let mut flags = 0; - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] match ERR_get_error_line_data(&mut file, &mut line, &mut data, &mut flags) { 0 => None, code => { @@ -155,7 +155,7 @@ impl fmt::Display for StackError { } } - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] { let func_error = unsafe { ERR_func_error_string(self.code) }; if !func_error.is_null() { @@ -241,7 +241,7 @@ const fn error_system_error(code: c_ulong) -> bool { } pub(crate) const fn error_get_lib(code: c_ulong) -> c_int { - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] return ((code >> 24) & 0x0FF) as c_int; #[cfg(feature = "c_openssl_3_0")] @@ -251,7 +251,7 @@ pub(crate) const fn error_get_lib(code: c_ulong) -> c_int { #[allow(unused_variables)] const fn error_get_func(code: c_ulong) -> c_int { - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] return ((code >> 12) & 0xFFF) as c_int; #[cfg(feature = "c_openssl_3_0")] @@ -259,7 +259,7 @@ const fn error_get_func(code: c_ulong) -> c_int { } pub(crate) const fn error_get_reason(code: c_ulong) -> c_int { - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] return (code & 0xFFF) as c_int; #[cfg(feature = "c_openssl_3_0")] diff --git a/ylong_http_client/src/util/c_openssl/ffi/err.rs b/ylong_http_client/src/util/c_openssl/ffi/err.rs index 093ae08..2693992 100644 --- a/ylong_http_client/src/util/c_openssl/ffi/err.rs +++ b/ylong_http_client/src/util/c_openssl/ffi/err.rs @@ -37,7 +37,7 @@ extern "C" { flags: *mut c_int, ) -> c_ulong; - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] pub(crate) fn ERR_get_error_line_data( file: *mut *const c_char, line: *mut c_int, @@ -45,6 +45,6 @@ extern "C" { flags: *mut c_int, ) -> c_ulong; - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] pub(crate) fn ERR_func_error_string(err: c_ulong) -> *const c_char; } diff --git a/ylong_http_client/src/util/c_openssl/ffi/mod.rs b/ylong_http_client/src/util/c_openssl/ffi/mod.rs index 830257e..d7373e3 100644 --- a/ylong_http_client/src/util/c_openssl/ffi/mod.rs +++ b/ylong_http_client/src/util/c_openssl/ffi/mod.rs @@ -18,7 +18,7 @@ pub(crate) mod bio; pub(crate) mod callback; pub(crate) mod err; pub(crate) mod pem; -pub(crate) mod ssl; +pub mod ssl; // todo pub(crate) mod stack; pub(crate) mod x509; diff --git a/ylong_http_client/src/util/c_openssl/ffi/ssl.rs b/ylong_http_client/src/util/c_openssl/ffi/ssl.rs index 8b3ab8b..7f516a0 100644 --- a/ylong_http_client/src/util/c_openssl/ffi/ssl.rs +++ b/ylong_http_client/src/util/c_openssl/ffi/ssl.rs @@ -36,6 +36,7 @@ extern "C" { pub(crate) fn SSL_CTX_up_ref(x: *mut SSL_CTX) -> c_int; /// Internal handling functions for SSL_CTX objects. + #[cfg(feature = "__c_openssl")] pub(crate) fn SSL_CTX_ctrl( ctx: *mut SSL_CTX, cmd: c_int, @@ -55,9 +56,6 @@ extern "C" { /// This function does not impact TLSv1.3 ciphersuites. pub(crate) fn SSL_CTX_set_cipher_list(ssl: *mut SSL_CTX, s: *const c_char) -> c_int; - /// Uses to configure the available TLSv1.3 ciphersuites for ctx. - pub(crate) fn SSL_CTX_set_ciphersuites(ctx: *mut SSL_CTX, str: *const c_char) -> c_int; - /// Loads the first certificate stored in file into ctx. /// The formatting type of the certificate must be specified from the known /// types SSL_FILETYPE_PEM, SSL_FILETYPE_ASN1. @@ -126,6 +124,20 @@ extern "C" { arg: *mut c_void, ); + #[cfg(feature = "c_boringssl")] + pub(crate) fn SSL_CTX_set_min_proto_version( + ctx: *mut SSL_CTX, + version: libc::c_ushort, + ) -> c_int; + + #[cfg(feature = "c_boringssl")] + pub(crate) fn SSL_CTX_set_max_proto_version( + ctx: *mut SSL_CTX, + version: libc::c_ushort, + ) -> c_int; + + #[cfg(feature = "c_boringssl")] + pub(crate) fn SSL_CTX_set1_sigalgs_list(ctx: *mut SSL_CTX, parg: *mut c_void) -> c_int; } /// This is the main SSL/TLS structure which is created by a server or client @@ -159,8 +171,9 @@ extern "C" { #[cfg(feature = "c_openssl_3_0")] pub(crate) fn SSL_get1_peer_certificate(ssl: *const SSL) -> *mut C_X509; + // use 1.1 in boringssl - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] pub(crate) fn SSL_get_peer_certificate(ssl: *const SSL) -> *mut C_X509; pub(crate) fn SSL_set_bio(ssl: *mut SSL, rbio: *mut BIO, wbio: *mut BIO); @@ -175,12 +188,16 @@ extern "C" { pub(crate) fn SSL_shutdown(ssl: *mut SSL) -> c_int; + #[cfg(feature = "__c_openssl")] pub(crate) fn SSL_ctrl(ssl: *mut SSL, cmd: c_int, larg: c_long, parg: *mut c_void) -> c_long; /// Retrieve an internal pointer to the verification parameters for ssl /// respectively. The returned pointer must not be freed by the calling /// application. pub(crate) fn SSL_get0_param(ssl: *mut SSL) -> *mut X509_VERIFY_PARAM; + + #[cfg(feature = "c_boringssl")] + pub(crate) fn SSL_set_tlsext_host_name(ssl: *mut SSL, name: *mut c_void) -> c_int; } /// This is a dispatch structure describing the internal ssl library diff --git a/ylong_http_client/src/util/c_openssl/ffi/x509.rs b/ylong_http_client/src/util/c_openssl/ffi/x509.rs index 8623a60..2f046aa 100644 --- a/ylong_http_client/src/util/c_openssl/ffi/x509.rs +++ b/ylong_http_client/src/util/c_openssl/ffi/x509.rs @@ -99,6 +99,13 @@ extern "C" { /// storage. #[cfg(feature = "c_openssl_3_0")] pub(crate) fn X509_STORE_load_path(store: *mut X509_STORE, x: *const c_char) -> c_int; + + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] + pub(crate) fn X509_STORE_load_locations( + store: *mut X509_STORE, + file: *const c_char, + dir: *const c_char, + ) -> c_int; } pub(crate) enum X509_STORE_CTX {} diff --git a/ylong_http_client/src/util/c_openssl/mod.rs b/ylong_http_client/src/util/c_openssl/mod.rs index 3188e47..1e0ae58 100644 --- a/ylong_http_client/src/util/c_openssl/mod.rs +++ b/ylong_http_client/src/util/c_openssl/mod.rs @@ -18,10 +18,13 @@ #[macro_use] mod foreign; mod bio; -mod ffi; +pub mod ffi; pub(crate) mod error; pub(crate) mod ssl; + +// todo +#[allow(dead_code)] pub(crate) mod stack; pub(crate) mod x509; diff --git a/ylong_http_client/src/util/c_openssl/ssl/ctx.rs b/ylong_http_client/src/util/c_openssl/ssl/ctx.rs index bd46f66..c23bd61 100644 --- a/ylong_http_client/src/util/c_openssl/ssl/ctx.rs +++ b/ylong_http_client/src/util/c_openssl/ssl/ctx.rs @@ -15,7 +15,7 @@ use core::{fmt, mem, ptr}; use std::ffi::CString; use std::path::Path; -use libc::{c_int, c_long, c_uint, c_void}; +use libc::{c_int, c_uint, c_void}; use super::filetype::SslFiletype; use super::method::SslMethod; @@ -25,18 +25,26 @@ use crate::c_openssl::ffi::ssl::{ }; use crate::c_openssl::x509::{X509Store, X509StoreRef}; use crate::util::c_openssl::error::ErrorStack; +#[cfg(feature = "__c_openssl")] +use crate::util::c_openssl::ffi::ssl::SSL_CTX_ctrl; use crate::util::c_openssl::ffi::ssl::{ - SSL_CTX_ctrl, SSL_CTX_load_verify_locations, SSL_CTX_new, SSL_CTX_set_alpn_protos, - SSL_CTX_set_cert_store, SSL_CTX_set_cert_verify_callback, SSL_CTX_set_cipher_list, - SSL_CTX_set_ciphersuites, SSL_CTX_up_ref, SSL_CTX_use_certificate_chain_file, - SSL_CTX_use_certificate_file, SSL_CTX, + SSL_CTX_load_verify_locations, SSL_CTX_new, SSL_CTX_set_alpn_protos, SSL_CTX_set_cert_store, + SSL_CTX_set_cert_verify_callback, SSL_CTX_set_cipher_list, SSL_CTX_up_ref, + SSL_CTX_use_certificate_chain_file, SSL_CTX_use_certificate_file, SSL_CTX, +}; +#[cfg(feature = "c_boringssl")] +use crate::util::c_openssl::ffi::ssl::{ + SSL_CTX_set1_sigalgs_list, SSL_CTX_set_max_proto_version, SSL_CTX_set_min_proto_version, }; use crate::util::c_openssl::foreign::{Foreign, ForeignRef}; use crate::util::c_openssl::{cert_verify, check_ptr, check_ret, ssl_init}; use crate::util::config::tls::DefaultCertVerifier; +#[cfg(feature = "__c_openssl")] const SSL_CTRL_SET_MIN_PROTO_VERSION: c_int = 123; +#[cfg(feature = "__c_openssl")] const SSL_CTRL_SET_MAX_PROTO_VERSION: c_int = 124; +#[cfg(feature = "__c_openssl")] const SSL_CTRL_SET_SIGALGS_LIST: c_int = 98; foreign_type!( @@ -117,29 +125,42 @@ impl SslContextBuilder { pub(crate) fn set_min_proto_version(&mut self, version: SslVersion) -> Result<(), ErrorStack> { let ptr = self.as_ptr_mut(); - check_ret(unsafe { + #[cfg(feature = "__c_openssl")] + return check_ret(unsafe { SSL_CTX_ctrl( ptr, SSL_CTRL_SET_MIN_PROTO_VERSION, - version.0 as c_long, + version.0 as libc::c_long, ptr::null_mut(), ) } as c_int) - .map(|_| ()) + .map(|_| ()); + + #[cfg(feature = "c_boringssl")] + return check_ret( + unsafe { SSL_CTX_set_min_proto_version(ptr, version.0 as libc::c_ushort) } as c_int, + ) + .map(|_| ()); } pub(crate) fn set_max_proto_version(&mut self, version: SslVersion) -> Result<(), ErrorStack> { let ptr = self.as_ptr_mut(); - check_ret(unsafe { + #[cfg(feature = "__c_openssl")] + return check_ret(unsafe { SSL_CTX_ctrl( ptr, SSL_CTRL_SET_MAX_PROTO_VERSION, - version.0 as c_long, + version.0 as libc::c_long, ptr::null_mut(), ) } as c_int) - .map(|_| ()) + .map(|_| ()); + #[cfg(feature = "c_boringssl")] + return check_ret( + unsafe { SSL_CTX_set_max_proto_version(ptr, version.0 as libc::c_ushort) } as c_int, + ) + .map(|_| ()); } /// Loads trusted root certificates from a file.\ @@ -169,17 +190,6 @@ impl SslContextBuilder { check_ret(unsafe { SSL_CTX_set_cipher_list(ptr, list.as_ptr() as *const _) }).map(|_| ()) } - /// Sets the list of supported ciphers for the `TLSv1.3` protocol. - pub(crate) fn set_cipher_suites(&mut self, list: &str) -> Result<(), ErrorStack> { - let list = match CString::new(list) { - Ok(cstr) => cstr, - Err(_) => return Err(ErrorStack::get()), - }; - let ptr = self.as_ptr_mut(); - - check_ret(unsafe { SSL_CTX_set_ciphersuites(ptr, list.as_ptr() as *const _) }).map(|_| ()) - } - /// Loads a leaf certificate from a file. /// /// Only a single certificate will be loaded - use `add_extra_chain_cert` to @@ -290,17 +300,16 @@ impl SslContextBuilder { // SHA512 DSA (0x0602) const SUPPORT_SIGNATURE_ALGORITHMS: &str = "\ ECDSA+SHA256:ECDSA+SHA384:ECDSA+SHA512:ed25519:\ - ed448:rsa_pss_pss_sha256:rsa_pss_pss_sha384:\ - rsa_pss_pss_sha512:rsa_pss_rsae_sha256:rsa_pss_rsae_sha384:\ - rsa_pss_rsae_sha512:rsa_pkcs1_sha256:rsa_pkcs1_sha384:rsa_pkcs1_sha512:DSA+SHA256:DSA+SHA384:DSA+SHA512"; + rsa_pss_rsae_sha256:rsa_pss_rsae_sha384:\ + rsa_pss_rsae_sha512:rsa_pkcs1_sha256:rsa_pkcs1_sha384:rsa_pkcs1_sha512"; let list = match CString::new(SUPPORT_SIGNATURE_ALGORITHMS) { Ok(cstr) => cstr, Err(_) => return Err(ErrorStack::get()), }; let ptr = self.as_ptr_mut(); - - check_ret(unsafe { + #[cfg(feature = "__c_openssl")] + return check_ret(unsafe { SSL_CTX_ctrl( ptr, SSL_CTRL_SET_SIGALGS_LIST, @@ -308,6 +317,11 @@ impl SslContextBuilder { list.as_ptr() as *const c_void as *mut c_void, ) } as c_int) - .map(|_| ()) + .map(|_| ()); + #[cfg(feature = "c_boringssl")] + return check_ret(unsafe { + SSL_CTX_set1_sigalgs_list(ptr, list.as_ptr() as *const c_void as *mut c_void) + } as c_int) + .map(|_| ()); } } diff --git a/ylong_http_client/src/util/c_openssl/ssl/mod.rs b/ylong_http_client/src/util/c_openssl/ssl/mod.rs index 5e7e0d3..ab0dfd0 100644 --- a/ylong_http_client/src/util/c_openssl/ssl/mod.rs +++ b/ylong_http_client/src/util/c_openssl/ssl/mod.rs @@ -24,5 +24,7 @@ pub(crate) use error::{InternalError, SslError, SslErrorCode}; pub(crate) use filetype::SslFiletype; pub(crate) use method::SslMethod; pub(crate) use ssl_base::{Ssl, SslRef}; +#[cfg(feature = "http3")] +pub(crate) use stream::verify_server_cert; pub(crate) use stream::{MidHandshakeSslStream, ShutdownResult, SslStream}; pub(crate) use version::SslVersion; diff --git a/ylong_http_client/src/util/c_openssl/ssl/ssl_base.rs b/ylong_http_client/src/util/c_openssl/ssl/ssl_base.rs index f9820c7..467af18 100644 --- a/ylong_http_client/src/util/c_openssl/ssl/ssl_base.rs +++ b/ylong_http_client/src/util/c_openssl/ssl/ssl_base.rs @@ -28,7 +28,7 @@ use super::{SslContext, SslErrorCode}; use crate::c_openssl::check_ret; use crate::c_openssl::ffi::bio::BIO; use crate::c_openssl::ffi::ssl::{ - SSL_ctrl, SSL_get0_param, SSL_get_error, SSL_get_rbio, SSL_get_verify_result, SSL_read, + SSL_get0_param, SSL_get_error, SSL_get_rbio, SSL_get_verify_result, SSL_read, SSL_state_string_long, SSL_write, }; use crate::c_openssl::foreign::ForeignRef; @@ -56,6 +56,11 @@ impl Ssl { } } + #[cfg(feature = "http3")] + pub(crate) fn get_raw_ptr(&mut self) -> *mut SSL { + self.as_ptr() + } + /// Client connect to Server. /// only `sync` use. #[cfg(feature = "sync")] @@ -168,14 +173,25 @@ impl fmt::Debug for SslRef { } } +#[cfg(feature = "__c_openssl")] const SSL_CTRL_SET_TLSEXT_HOSTNAME: c_int = 0x37; +#[cfg(feature = "__c_openssl")] const TLSEXT_NAMETYPE_HOST_NAME: c_int = 0x0; unsafe fn ssl_set_tlsext_host_name(s: *mut SSL, name: *mut c_char) -> c_long { - SSL_ctrl( + #[cfg(feature = "__c_openssl")] + use crate::c_openssl::ffi::ssl::SSL_ctrl; + #[cfg(feature = "c_boringssl")] + use crate::c_openssl::ffi::ssl::SSL_set_tlsext_host_name; + + #[cfg(feature = "__c_openssl")] + return SSL_ctrl( s, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_HOST_NAME as c_long, name as *mut c_void, - ) + ); + + #[cfg(feature = "c_boringssl")] + return SSL_set_tlsext_host_name(s, name as *mut c_void) as c_long; } diff --git a/ylong_http_client/src/util/c_openssl/ssl/stream.rs b/ylong_http_client/src/util/c_openssl/ssl/stream.rs index d0a7d3f..1e1263e 100644 --- a/ylong_http_client/src/util/c_openssl/ssl/stream.rs +++ b/ylong_http_client/src/util/c_openssl/ssl/stream.rs @@ -253,10 +253,10 @@ pub(crate) enum ShutdownResult { } // TODO The SSLError thrown here is meaningless and has no information. -fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> { +pub(crate) fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> { #[cfg(feature = "c_openssl_3_0")] use crate::util::c_openssl::ffi::ssl::SSL_get1_peer_certificate; - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] use crate::util::c_openssl::ffi::ssl::SSL_get_peer_certificate; let certificate = unsafe { @@ -264,7 +264,7 @@ fn verify_server_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> { SSL_get1_peer_certificate(ssl) } - #[cfg(feature = "c_openssl_1_1")] + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] { SSL_get_peer_certificate(ssl) } diff --git a/ylong_http_client/src/util/c_openssl/x509.rs b/ylong_http_client/src/util/c_openssl/x509.rs index 0846efc..54f4bee 100644 --- a/ylong_http_client/src/util/c_openssl/x509.rs +++ b/ylong_http_client/src/util/c_openssl/x509.rs @@ -12,7 +12,6 @@ // limitations under the License. use core::{ffi, fmt, ptr, str}; -#[cfg(feature = "c_openssl_3_0")] use std::ffi::CString; use std::net::IpAddr; @@ -22,8 +21,6 @@ use super::bio::BioSlice; use super::error::{error_get_lib, error_get_reason, ErrorStack}; use super::ffi::err::{ERR_clear_error, ERR_peek_last_error}; use super::ffi::pem::PEM_read_bio_X509; -#[cfg(feature = "c_openssl_3_0")] -use super::ffi::x509::X509_STORE_load_path; use super::ffi::x509::{ d2i_X509, EVP_PKEY_free, X509_NAME_free, X509_NAME_oneline, X509_PUBKEY_free, X509_STORE_CTX_free, X509_STORE_CTX_get0_cert, X509_STORE_add_cert, X509_STORE_free, @@ -59,6 +56,9 @@ foreign_type! { } const ERR_LIB_PEM: c_int = 9; +#[cfg(feature = "c_boringssl")] +const PEM_R_NO_START_LINE: c_int = 110; +#[cfg(feature = "__c_openssl")] const PEM_R_NO_START_LINE: c_int = 108; impl X509 { @@ -96,7 +96,6 @@ impl X509 { ERR_clear_error(); break; } - return Err(ErrorStack::get()); } else { certs.push(X509(r)); @@ -217,15 +216,27 @@ impl X509StoreRef { check_ret(unsafe { X509_STORE_add_cert(self.as_ptr(), cert.as_ptr()) }).map(|_| ()) } - #[cfg(feature = "c_openssl_3_0")] pub(crate) fn add_path(&mut self, path: String) -> Result<(), ErrorStack> { + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] + use super::ffi::x509::X509_STORE_load_locations; + #[cfg(feature = "c_openssl_3_0")] + use super::ffi::x509::X509_STORE_load_path; + let p_slice: &str = &path; let path = match CString::new(p_slice) { Ok(cstr) => cstr, Err(_) => return Err(ErrorStack::get()), }; - check_ret(unsafe { X509_STORE_load_path(self.as_ptr(), path.as_ptr() as *const _) }) - .map(|_| ()) + #[cfg(feature = "c_openssl_3_0")] + return check_ret(unsafe { + X509_STORE_load_path(self.as_ptr(), path.as_ptr() as *const _) + }) + .map(|_| ()); + #[cfg(any(feature = "c_openssl_1_1", feature = "c_boringssl"))] + return check_ret(unsafe { + X509_STORE_load_locations(self.as_ptr(), ptr::null(), path.as_ptr() as *const _) + }) + .map(|_| ()); } } diff --git a/ylong_http_client/src/util/config/connector.rs b/ylong_http_client/src/util/config/connector.rs index c61d49f..d05d04f 100644 --- a/ylong_http_client/src/util/config/connector.rs +++ b/ylong_http_client/src/util/config/connector.rs @@ -13,7 +13,7 @@ //! Connector configure module. -#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] +#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] use super::FchownConfig; use crate::util::proxy::Proxies; @@ -21,7 +21,7 @@ use crate::util::proxy::Proxies; pub(crate) struct ConnectorConfig { pub(crate) proxies: Proxies, - #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] + #[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] pub(crate) fchown: Option, #[cfg(feature = "__tls")] diff --git a/ylong_http_client/src/util/config/http.rs b/ylong_http_client/src/util/config/http.rs index 7d0a723..08190a6 100644 --- a/ylong_http_client/src/util/config/http.rs +++ b/ylong_http_client/src/util/config/http.rs @@ -12,6 +12,8 @@ // limitations under the License. //! HTTP configure module. +#[cfg(feature = "http3")] +use crate::ErrorKind; /// Options and flags which can be used to configure `HTTP` related logic. #[derive(Clone)] @@ -20,6 +22,9 @@ pub(crate) struct HttpConfig { #[cfg(feature = "http2")] pub(crate) http2_config: http2::H2Config, + + #[cfg(feature = "http3")] + pub(crate) http3_config: http3::H3Config, } impl HttpConfig { @@ -30,6 +35,9 @@ impl HttpConfig { #[cfg(feature = "http2")] http2_config: http2::H2Config::new(), + + #[cfg(feature = "http3")] + http3_config: http3::H3Config::new(), } } } @@ -42,7 +50,7 @@ impl Default for HttpConfig { /// `HTTP` version to use. #[derive(PartialEq, Eq, Clone)] -pub(crate) enum HttpVersion { +pub enum HttpVersion { /// Enforces `HTTP/1.1` or `HTTP/1.0` requests. Http1, @@ -50,10 +58,31 @@ pub(crate) enum HttpVersion { /// Enforce `HTTP/2.0` requests without `HTTP/1.1` Upgrade or ALPN. Http2, + #[cfg(feature = "http3")] + /// Enforces `HTTP/3` requests. + Http3, + /// Negotiate the protocol version through the ALPN. Negotiate, } +#[cfg(feature = "http3")] +impl TryFrom<&[u8]> for HttpVersion { + type Error = ErrorKind; + + fn try_from(value: &[u8]) -> Result { + if value == b"h1" { + Ok(HttpVersion::Http1) + } else if value == b"h2" { + Ok(HttpVersion::Http2) + } else if value == b"h3" { + Ok(HttpVersion::Http3) + } else { + Err(ErrorKind::Other) + } + } +} + #[cfg(feature = "http2")] pub(crate) mod http2 { const DEFAULT_MAX_FRAME_SIZE: u32 = 16 * 1024; @@ -167,3 +196,88 @@ pub(crate) mod http2 { } } } + +#[cfg(feature = "http3")] +pub(crate) mod http3 { + const DEFAULT_MAX_FIELD_SECTION_SIZE: u64 = 16 * 1024; + const DEFAULT_QPACK_MAX_TABLE_CAPACITY: u64 = 16 * 1024; + const DEFAULT_QPACK_BLOCKED_STREAMS: u64 = 10; + + // todo: which settings should be pub to user + #[derive(Clone)] + pub(crate) struct H3Config { + max_field_section_size: u64, + qpack_max_table_capacity: u64, + qpack_blocked_streams: u64, + connect_protocol_enabled: Option, + additional_settings: Option>, + } + + impl H3Config { + /// `H3Config` constructor. + + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn set_max_field_section_size(&mut self, size: u64) { + self.max_field_section_size = size; + } + + pub(crate) fn set_qpack_max_table_capacity(&mut self, size: u64) { + self.qpack_max_table_capacity = size; + } + + pub(crate) fn set_qpack_blocked_streams(&mut self, size: u64) { + self.qpack_blocked_streams = size; + } + + #[allow(unused)] + fn set_connect_protocol_enabled(&mut self, size: u64) { + self.connect_protocol_enabled = Some(size); + } + + #[allow(unused)] + fn insert_additional_settings(&mut self, key: u64, value: u64) { + if let Some(vec) = &mut self.additional_settings { + vec.push((key, value)); + } else { + self.additional_settings = Some(vec![(key, value)]); + } + } + + pub(crate) fn max_field_section_size(&self) -> u64 { + self.max_field_section_size + } + + pub(crate) fn qpack_max_table_capacity(&self) -> u64 { + self.qpack_max_table_capacity + } + + pub(crate) fn qpack_blocked_streams(&self) -> u64 { + self.qpack_blocked_streams + } + + #[allow(unused)] + fn connect_protocol_enabled(&mut self) -> Option { + self.connect_protocol_enabled + } + + #[allow(unused)] + fn additional_settings(&mut self) -> Option> { + self.additional_settings.clone() + } + } + + impl Default for H3Config { + fn default() -> Self { + Self { + max_field_section_size: DEFAULT_MAX_FIELD_SECTION_SIZE, + qpack_max_table_capacity: DEFAULT_QPACK_MAX_TABLE_CAPACITY, + qpack_blocked_streams: DEFAULT_QPACK_BLOCKED_STREAMS, + connect_protocol_enabled: None, + additional_settings: None, + } + } + } +} diff --git a/ylong_http_client/src/util/config/mod.rs b/ylong_http_client/src/util/config/mod.rs index efa1091..ce4c445 100644 --- a/ylong_http_client/src/util/config/mod.rs +++ b/ylong_http_client/src/util/config/mod.rs @@ -20,17 +20,19 @@ pub(crate) use client::ClientConfig; pub(crate) use connector::ConnectorConfig; #[cfg(feature = "http2")] pub(crate) use http::http2::H2Config; +#[cfg(feature = "http3")] +pub(crate) use http::http3::H3Config; pub(crate) use http::{HttpConfig, HttpVersion}; pub use settings::{Proxy, ProxyBuilder, Redirect, Retry, SpeedLimit, Timeout}; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub(crate) mod tls; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub(crate) use tls::{AlpnProtocol, AlpnProtocolList}; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub use tls::{CertVerifier, ServerCerts}; #[cfg(feature = "tls_rust_ssl")] pub use tls::{Certificate, PrivateKey, TlsConfig, TlsConfigBuilder, TlsFileType, TlsVersion}; -#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] +#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] mod fchown; -#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__c_openssl"))] +#[cfg(all(target_os = "linux", feature = "ylong_base", feature = "__tls"))] pub(crate) use fchown::FchownConfig; diff --git a/ylong_http_client/src/util/h2/data_ref.rs b/ylong_http_client/src/util/data_ref.rs similarity index 81% rename from ylong_http_client/src/util/h2/data_ref.rs rename to ylong_http_client/src/util/data_ref.rs index 80d108a..1d53e11 100644 --- a/ylong_http_client/src/util/h2/data_ref.rs +++ b/ylong_http_client/src/util/data_ref.rs @@ -16,8 +16,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; -use ylong_http::h2::{ErrorCode, H2Error}; - use crate::runtime::{AsyncRead, ReadBuf}; use crate::util::request::RequestArc; @@ -40,20 +38,18 @@ impl BodyDataRef { &mut self, cx: &mut Context<'_>, buf: &mut [u8], - ) -> Poll> { + ) -> Poll> { let request = if let Some(ref mut request) = self.body { request } else { - return Poll::Ready(Ok(0)); + return Poll::Ready(Some(0)); }; let data = request.ref_mut().body_mut(); let mut read_buf = ReadBuf::new(buf); let data = Pin::new(data); match data.poll_read(cx, &mut read_buf) { - Poll::Ready(Err(_e)) => { - Poll::Ready(Err(H2Error::ConnectionError(ErrorCode::IntervalError))) - } - Poll::Ready(Ok(_)) => Poll::Ready(Ok(read_buf.filled().len())), + Poll::Ready(Err(_)) => Poll::Ready(None), + Poll::Ready(Ok(_)) => Poll::Ready(Some(read_buf.filled().len())), Poll::Pending => Poll::Pending, } } diff --git a/ylong_http_client/src/util/dispatcher.rs b/ylong_http_client/src/util/dispatcher.rs index 82e3d72..b897f93 100644 --- a/ylong_http_client/src/util/dispatcher.rs +++ b/ylong_http_client/src/util/dispatcher.rs @@ -17,6 +17,9 @@ pub(crate) trait Dispatcher { fn dispatch(&self) -> Option; fn is_shutdown(&self) -> bool; + + #[allow(dead_code)] + fn is_goaway(&self) -> bool; } pub(crate) enum ConnDispatcher { @@ -25,6 +28,9 @@ pub(crate) enum ConnDispatcher { #[cfg(feature = "http2")] Http2(http2::Http2Dispatcher), + + #[cfg(feature = "http3")] + Http3(http3::Http3Dispatcher), } impl Dispatcher for ConnDispatcher { @@ -37,6 +43,9 @@ impl Dispatcher for ConnDispatcher { #[cfg(feature = "http2")] Self::Http2(h2) => h2.dispatch().map(Conn::Http2), + + #[cfg(feature = "http3")] + Self::Http3(h3) => h3.dispatch().map(Conn::Http3), } } @@ -47,6 +56,22 @@ impl Dispatcher for ConnDispatcher { #[cfg(feature = "http2")] Self::Http2(h2) => h2.is_shutdown(), + + #[cfg(feature = "http3")] + Self::Http3(h3) => h3.is_shutdown(), + } + } + + fn is_goaway(&self) -> bool { + match self { + #[cfg(feature = "http1_1")] + Self::Http1(h1) => h1.is_goaway(), + + #[cfg(feature = "http2")] + Self::Http2(h2) => h2.is_goaway(), + + #[cfg(feature = "http3")] + Self::Http3(h3) => h3.is_goaway(), } } } @@ -57,6 +82,9 @@ pub(crate) enum Conn { #[cfg(feature = "http2")] Http2(http2::Http2Conn), + + #[cfg(feature = "http3")] + Http3(http3::Http3Conn), } #[cfg(feature = "http1_1")] @@ -119,6 +147,10 @@ pub(crate) mod http1 { fn is_shutdown(&self) -> bool { self.inner.shutdown.load(Ordering::Relaxed) } + + fn is_goaway(&self) -> bool { + false + } } /// Handle returned to other threads for I/O operations. @@ -380,6 +412,11 @@ pub(crate) mod http2 { fn is_shutdown(&self) -> bool { self.io_shutdown.load(Ordering::Relaxed) } + + fn is_goaway(&self) -> bool { + // todo: goaway and shutdown + false + } } impl Drop for Http2Dispatcher { @@ -660,6 +697,250 @@ pub(crate) mod http2 { } } +#[cfg(feature = "http3")] +pub(crate) mod http3 { + use std::marker::PhantomData; + use std::pin::Pin; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::{Arc, Mutex}; + + use ylong_http::error::HttpError; + use ylong_http::h3::{Frame, FrameDecoder, H3Error}; + + use crate::async_impl::{ConnInfo, QuicConn}; + use crate::runtime::{ + bounded_channel, unbounded_channel, AsyncRead, AsyncWrite, BoundedReceiver, BoundedSender, + UnboundedSender, + }; + use crate::util::config::H3Config; + use crate::util::data_ref::BodyDataRef; + use crate::util::dispatcher::{ConnDispatcher, Dispatcher}; + use crate::util::h3::io_manager::IOManager; + use crate::util::h3::stream_manager::StreamManager; + use crate::ErrorKind::Request; + use crate::{ErrorKind, HttpClientError}; + + pub(crate) struct Http3Dispatcher { + pub(crate) req_tx: UnboundedSender, + pub(crate) handles: Vec>, + pub(crate) _mark: PhantomData, + pub(crate) io_shutdown: Arc, + pub(crate) io_goaway: Arc, + } + + pub(crate) struct Http3Conn { + pub(crate) sender: UnboundedSender, + pub(crate) resp_receiver: BoundedReceiver, + pub(crate) resp_sender: BoundedSender, + pub(crate) io_shutdown: Arc, + pub(crate) _mark: PhantomData, + } + + pub(crate) struct RequestWrapper { + pub(crate) header: Frame, + pub(crate) data: BodyDataRef, + } + + #[derive(Debug, Clone)] + pub(crate) enum DispatchErrorKind { + H3(H3Error), + Io(std::io::ErrorKind), + Quic(quiche::Error), + ChannelClosed, + StreamFinished, + // todo: retry? + GoawayReceived, + Disconnect, + } + + pub(crate) enum RespMessage { + Output(Frame), + OutputExit(DispatchErrorKind), + } + + pub(crate) struct ReqMessage { + pub(crate) request: RequestWrapper, + pub(crate) frame_tx: BoundedSender, + } + + impl Http3Dispatcher + where + S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static, + { + pub(crate) fn new(config: H3Config, io: S, quic_connection: QuicConn) -> Self { + let (req_tx, req_rx) = unbounded_channel(); + let (io_manager_tx, io_manager_rx) = unbounded_channel(); + let (stream_manager_tx, stream_manager_rx) = unbounded_channel(); + let mut handles = Vec::with_capacity(2); + let conn = Arc::new(Mutex::new(quic_connection)); + let io_shutdown = Arc::new(AtomicBool::new(false)); + let io_goaway = Arc::new(AtomicBool::new(false)); + let mut stream_manager = StreamManager::new( + conn.clone(), + io_manager_tx, + stream_manager_rx, + req_rx, + FrameDecoder::new( + config.qpack_blocked_streams() as usize, + config.qpack_max_table_capacity() as usize, + ), + io_shutdown.clone(), + io_goaway.clone(), + ); + let stream_handle = crate::runtime::spawn(async move { + if stream_manager.init(config).is_err() { + return; + } + let _ = Pin::new(&mut stream_manager).await; + }); + handles.push(stream_handle); + + let io_handle = crate::runtime::spawn(async move { + let mut io_manager = IOManager::new(io, conn, io_manager_rx, stream_manager_tx); + let _ = Pin::new(&mut io_manager).await; + }); + handles.push(io_handle); + // read_rx gets readable stream ids and writable client channels, then read + // stream and send to the corresponding channel + Self { + req_tx, + handles, + _mark: PhantomData, + io_shutdown, + io_goaway, + } + } + } + + impl Http3Conn { + pub(crate) fn new( + sender: UnboundedSender, + io_shutdown: Arc, + ) -> Self { + const CHANNEL_SIZE: usize = 3; + let (resp_sender, resp_receiver) = bounded_channel(CHANNEL_SIZE); + Self { + sender, + resp_sender, + resp_receiver, + _mark: PhantomData, + io_shutdown, + } + } + + pub(crate) fn send_frame_to_reader( + &mut self, + request: RequestWrapper, + ) -> Result<(), HttpClientError> { + self.sender + .send(ReqMessage { + request, + frame_tx: self.resp_sender.clone(), + }) + .map_err(|_| { + HttpClientError::from_str(ErrorKind::Request, "Request Sender Closed !") + }) + } + + pub(crate) async fn recv_resp(&mut self) -> Result { + #[cfg(feature = "tokio_base")] + match self.resp_receiver.recv().await { + None => err_from_msg!(Request, "Response Receiver Closed !"), + Some(message) => match message { + RespMessage::Output(frame) => Ok(frame), + RespMessage::OutputExit(e) => Err(dispatch_client_error(e)), + }, + } + + #[cfg(feature = "ylong_base")] + match self.resp_receiver.recv().await { + Err(err) => Err(HttpClientError::from_error(ErrorKind::Request, err)), + Ok(message) => match message { + RespMessage::Output(frame) => Ok(frame), + RespMessage::OutputExit(e) => Err(dispatch_client_error(e)), + }, + } + } + } + + impl ConnDispatcher + where + S: AsyncRead + AsyncWrite + ConnInfo + Sync + Send + Unpin + 'static, + { + pub(crate) fn http3(config: H3Config, io: S, quic_connection: QuicConn) -> Self { + Self::Http3(Http3Dispatcher::new(config, io, quic_connection)) + } + } + + impl Dispatcher for Http3Dispatcher { + type Handle = Http3Conn; + + fn dispatch(&self) -> Option { + let sender = self.req_tx.clone(); + Some(Http3Conn::new(sender, self.io_shutdown.clone())) + } + + fn is_shutdown(&self) -> bool { + self.io_shutdown.load(Ordering::Relaxed) + } + + fn is_goaway(&self) -> bool { + self.io_goaway.load(Ordering::Relaxed) + } + } + + impl Drop for Http3Dispatcher { + fn drop(&mut self) { + for handle in &self.handles { + #[cfg(feature = "tokio_base")] + handle.abort(); + #[cfg(feature = "ylong_base")] + handle.cancel(); + } + } + } + + impl From for DispatchErrorKind { + fn from(value: std::io::Error) -> Self { + DispatchErrorKind::Io(value.kind()) + } + } + + impl From for DispatchErrorKind { + fn from(err: H3Error) -> Self { + DispatchErrorKind::H3(err) + } + } + + impl From for DispatchErrorKind { + fn from(value: quiche::Error) -> Self { + DispatchErrorKind::Quic(value) + } + } + + pub(crate) fn dispatch_client_error(dispatch_error: DispatchErrorKind) -> HttpClientError { + match dispatch_error { + DispatchErrorKind::H3(e) => HttpClientError::from_error(Request, HttpError::from(e)), + DispatchErrorKind::Io(e) => { + HttpClientError::from_io_error(Request, std::io::Error::from(e)) + } + DispatchErrorKind::ChannelClosed => { + HttpClientError::from_str(Request, "Coroutine channel closed.") + } + DispatchErrorKind::Quic(e) => HttpClientError::from_error(Request, e), + DispatchErrorKind::GoawayReceived => { + HttpClientError::from_str(Request, "received remote goaway.") + } + DispatchErrorKind::StreamFinished => { + HttpClientError::from_str(Request, "stream finished.") + } + DispatchErrorKind::Disconnect => { + HttpClientError::from_str(Request, "remote peer closed.") + } + } + } +} + #[cfg(test)] mod ut_dispatch { use crate::dispatcher::{ConnDispatcher, Dispatcher}; diff --git a/ylong_http_client/src/util/h2/mod.rs b/ylong_http_client/src/util/h2/mod.rs index 0916468..1c8eeb6 100644 --- a/ylong_http_client/src/util/h2/mod.rs +++ b/ylong_http_client/src/util/h2/mod.rs @@ -22,7 +22,6 @@ //! receiving of multiple streams. mod buffer; -mod data_ref; mod input; mod manager; mod output; @@ -32,7 +31,6 @@ mod streams; mod io; pub(crate) use buffer::FlowControl; -pub(crate) use data_ref::BodyDataRef; pub(crate) use input::SendData; #[cfg(feature = "ylong_base")] pub(crate) use io::{split, Reader, Writer}; diff --git a/ylong_http_client/src/util/h2/streams.rs b/ylong_http_client/src/util/h2/streams.rs index d80a340..583cbd9 100644 --- a/ylong_http_client/src/util/h2/streams.rs +++ b/ylong_http_client/src/util/h2/streams.rs @@ -20,9 +20,9 @@ use std::task::{Context, Poll}; use ylong_http::h2::{Data, ErrorCode, Frame, FrameFlags, H2Error, Payload}; use crate::runtime::UnboundedSender; +use crate::util::data_ref::BodyDataRef; use crate::util::dispatcher::http2::DispatchErrorKind; use crate::util::h2::buffer::{FlowControl, RecvWindow, SendWindow}; -use crate::util::h2::data_ref::BodyDataRef; pub(crate) const INITIAL_MAX_SEND_STREAM_ID: u32 = u32::MAX >> 1; pub(crate) const INITIAL_MAX_RECV_STREAM_ID: u32 = u32::MAX >> 1; @@ -461,8 +461,8 @@ impl Streams { } else { return Err(H2Error::ConnectionError(ErrorCode::IntervalError)); }; - match stream.data.poll_read(cx, buf)? { - Poll::Ready(size) => { + match stream.data.poll_read(cx, buf) { + Poll::Ready(Some(size)) => { if size > 0 { stream.send_window.send_data(size as u32); self.flow_control.send_data(size as u32); @@ -485,6 +485,7 @@ impl Streams { ))) } } + Poll::Ready(None) => Err(H2Error::ConnectionError(ErrorCode::IntervalError)), Poll::Pending => { self.push_back_pending_send(id); Ok(DataReadState::Pending) diff --git a/ylong_http_client/src/util/h3/io_manager.rs b/ylong_http_client/src/util/h3/io_manager.rs new file mode 100644 index 0000000..f71b122 --- /dev/null +++ b/ylong_http_client/src/util/h3/io_manager.rs @@ -0,0 +1,225 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use ylong_runtime::time::{sleep, Sleep}; + +use crate::async_impl::{ConnInfo, QuicConn}; +use crate::runtime::{AsyncRead, AsyncWrite, ReadBuf, UnboundedReceiver, UnboundedSender}; +use crate::util::dispatcher::http3::DispatchErrorKind; +use crate::util::h3::stream_manager::UPD_RECV_BUF_SIZE; + +const UDP_SEND_BUF_SIZE: usize = 1350; + +enum IOManagerState { + IORecving, + Timeout, + IOSending, + ChannelRecving, +} + +pub(crate) struct IOManager { + io: S, + conn: Arc>, + io_manager_rx: UnboundedReceiver>, + stream_manager_tx: UnboundedSender>, + recv_timeout: Option>>, + state: IOManagerState, + recv_buf: [u8; UPD_RECV_BUF_SIZE], + send_data: SendData, +} + +impl IOManager { + pub(crate) fn new( + io: S, + conn: Arc>, + io_manager_rx: UnboundedReceiver>, + stream_manager_tx: UnboundedSender>, + ) -> Self { + Self { + io, + conn, + io_manager_rx, + stream_manager_tx, + recv_timeout: None, + state: IOManagerState::IORecving, + recv_buf: [0u8; UPD_RECV_BUF_SIZE], + send_data: SendData::new(), + } + } + fn poll_recv_signal( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, DispatchErrorKind>> { + #[cfg(feature = "tokio_base")] + match self.io_manager_rx.poll_recv(cx) { + Poll::Ready(None) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Some(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + #[cfg(feature = "ylong_base")] + match self.io_manager_rx.poll_recv(cx) { + Poll::Ready(Err(_e)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Ok(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_io_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut buf = ReadBuf::new(&mut self.recv_buf); + if self.recv_timeout.is_none() { + if let Some(time) = self.conn.lock().unwrap().timeout() { + self.recv_timeout = Some(Box::pin(sleep(time))); + }; + } + + if let Some(delay) = self.recv_timeout.as_mut() { + if let Poll::Ready(()) = delay.as_mut().poll(cx) { + self.recv_timeout = None; + self.conn.lock().unwrap().on_timeout(); + self.state = IOManagerState::Timeout; + return Poll::Ready(Ok(())); + } + } + match Pin::new(&mut self.io).poll_read(cx, &mut buf) { + Poll::Ready(Ok(())) => { + let info = self.io.conn_detail(); + self.recv_timeout = None; + let recv_info = quiche::RecvInfo { + to: info.local, + from: info.peer, + }; + return match self.conn.lock().unwrap().recv(buf.filled_mut(), recv_info) { + Ok(_) => { + let _ = self.stream_manager_tx.send(Ok(())); + // io recv once again + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(DispatchErrorKind::Quic(e))), + }; + } + Poll::Ready(Err(e)) => Poll::Ready(Err(DispatchErrorKind::Io(e.kind()))), + Poll::Pending => { + self.state = IOManagerState::IOSending; + Poll::Pending + } + } + } + + fn poll_io_send(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + // UDP buf has not been sent to the peer, send rest UDP buf first + if self.send_data.buf_size == self.send_data.offset { + // Retrieve the data to be sent via UDP from the connection + let size = match self.conn.lock().unwrap().send(&mut self.send_data.buf) { + Ok((size, _)) => size, + Err(quiche::Error::Done) => { + self.state = IOManagerState::ChannelRecving; + return Poll::Ready(Ok(())); + } + Err(e) => { + return Poll::Ready(Err(DispatchErrorKind::Quic(e))); + } + }; + self.send_data.buf_size = size; + self.send_data.offset = 0; + } + + match Pin::new(&mut self.io).poll_write( + cx, + &self.send_data.buf[self.send_data.offset..self.send_data.buf_size], + ) { + Poll::Ready(Ok(size)) => { + self.send_data.offset += size; + if self.send_data.offset != self.send_data.buf_size { + // loop to send UDP buf + continue; + } else { + self.send_data.offset = 0; + self.send_data.buf_size = 0; + } + } + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(DispatchErrorKind::Io(e.kind()))); + } + Poll::Pending => { + self.state = IOManagerState::ChannelRecving; + return Poll::Pending; + } + } + } + } +} + +impl Future for IOManager { + type Output = Result<(), DispatchErrorKind>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + this.state = IOManagerState::IORecving; + loop { + match this.state { + IOManagerState::IORecving => { + if let Poll::Ready(Err(e)) = this.poll_io_recv(cx) { + return Poll::Ready(Err(e)); + } + } + IOManagerState::IOSending => { + if let Poll::Ready(Err(e)) = this.poll_io_send(cx) { + return Poll::Ready(Err(e)); + } + } + IOManagerState::Timeout => { + if let Poll::Ready(Err(e)) = this.poll_io_send(cx) { + return Poll::Ready(Err(e)); + } + // ensure pending at io recv + this.state = IOManagerState::IORecving; + } + IOManagerState::ChannelRecving => match this.poll_recv_signal(cx) { + // won't recv Err now + Poll::Ready(Ok(_)) => { + continue; + } + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(e)); + } + Poll::Pending => { + this.state = IOManagerState::IORecving; + return Poll::Pending; + } + }, + } + } + } +} + +pub(crate) struct SendData { + pub(crate) buf: [u8; UDP_SEND_BUF_SIZE], + pub(crate) buf_size: usize, + pub(crate) offset: usize, +} + +impl SendData { + pub(crate) fn new() -> Self { + Self { + buf: [0u8; UDP_SEND_BUF_SIZE], + buf_size: 0, + offset: 0, + } + } +} diff --git a/ylong_http/src/h3/connection.rs b/ylong_http_client/src/util/h3/mod.rs similarity index 84% rename from ylong_http/src/h3/connection.rs rename to ylong_http_client/src/util/h3/mod.rs index 0d03a7e..ec5afa1 100644 --- a/ylong_http/src/h3/connection.rs +++ b/ylong_http_client/src/util/h3/mod.rs @@ -10,3 +10,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + +//! Http3 Protocol module. + +pub(crate) mod io_manager; +pub(crate) mod stream_manager; +pub(crate) mod streams; diff --git a/ylong_http_client/src/util/h3/stream_manager.rs b/ylong_http_client/src/util/h3/stream_manager.rs new file mode 100644 index 0000000..2834412 --- /dev/null +++ b/ylong_http_client/src/util/h3/stream_manager.rs @@ -0,0 +1,774 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Stream Manager module. + +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use quiche::Shutdown; +use ylong_http::h3::{ + Frame, FrameDecoder, FrameEncoder, FrameKind, Frames, H3Error, H3ErrorCode, Headers, Payload, + Settings, StreamMessage, CONTROL_STREAM_TYPE, QPACK_DECODER_STREAM_TYPE, + QPACK_ENCODER_STREAM_TYPE, SETTINGS_FRAME_TYPE, +}; + +use crate::async_impl::QuicConn; +use crate::runtime::{UnboundedReceiver, UnboundedSender}; +use crate::util::config::H3Config; +use crate::util::dispatcher::http3::{DispatchErrorKind, ReqMessage, RespMessage}; +use crate::util::h3::streams::{DataReadState, QUICStreamType, Streams}; + +pub(crate) const UPD_RECV_BUF_SIZE: usize = 65535; +const DECODE_BUF_SIZE: usize = 1024; + +pub(crate) struct StreamManager { + pub(crate) streams: Streams, + pub(crate) quic_conn: Arc>, + pub(crate) io_manager_tx: UnboundedSender>, + pub(crate) stream_manager_rx: UnboundedReceiver>, + pub(crate) req_rx: UnboundedReceiver, + pub(crate) stream_recv_buf: [u8; UPD_RECV_BUF_SIZE], + pub(crate) encoder: FrameEncoder, + pub(crate) decoder: FrameDecoder, + pub(crate) encoder_buf: [u8; DECODE_BUF_SIZE], + pub(crate) inst_buf: [u8; DECODE_BUF_SIZE], + pub(crate) peer_settings: Option, + pub(crate) io_shutdown: Arc, + pub(crate) io_goaway: Arc, +} + +impl StreamManager { + pub(crate) fn new( + quic_conn: Arc>, + io_manager_tx: UnboundedSender>, + stream_manager_rx: UnboundedReceiver>, + req_rx: UnboundedReceiver, + decoder: FrameDecoder, + io_shutdown: Arc, + io_goaway: Arc, + ) -> Self { + Self { + streams: Streams::new(), + quic_conn, + io_manager_tx, + stream_manager_rx, + req_rx, + stream_recv_buf: [0u8; UPD_RECV_BUF_SIZE], + encoder_buf: [0u8; DECODE_BUF_SIZE], + inst_buf: [0u8; DECODE_BUF_SIZE], + encoder: FrameEncoder::default(), + decoder, + peer_settings: None, + io_shutdown, + io_goaway, + } + } + + fn poll_recv_signal( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, DispatchErrorKind>> { + #[cfg(feature = "tokio_base")] + match self.stream_manager_rx.poll_recv(cx) { + Poll::Ready(None) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Some(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + #[cfg(feature = "ylong_base")] + match self.stream_manager_rx.poll_recv(cx) { + Poll::Ready(Err(_e)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Ok(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_recv_request( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + #[cfg(feature = "tokio_base")] + match self.req_rx.poll_recv(cx) { + Poll::Ready(None) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Some(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + #[cfg(feature = "ylong_base")] + match self.req_rx.poll_recv(cx) { + Poll::Ready(Err(_e)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Ready(Ok(data)) => Poll::Ready(Ok(data)), + Poll::Pending => Poll::Pending, + } + } + + fn send_inst_to_peer( + &mut self, + headers: &Headers, + quic_conn: &mut QuicConn, + ) -> Result<(), DispatchErrorKind> { + if let Some(vec) = headers.get_instruction() { + let qpack_decode_stream_id = + self.streams + .qpack_decode_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))?; + quic_conn.stream_send(qpack_decode_stream_id, vec, false)?; + } + Ok(()) + } + + fn transmit_error( + &mut self, + cx: &mut Context<'_>, + id: u64, + error: DispatchErrorKind, + ) -> Result<(), DispatchErrorKind> { + self.streams.send_error(cx, id, error) + } + + fn poll_input_request(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchErrorKind> { + self.streams.try_consume_pending_concurrency(); + let len = self.streams.pending_stream_len(); + // Some streams may be blocked due to the server not reading the message. Avoid + // reading these streams twice in one loop + for _ in 0..len { + if let Some(id) = self.streams.next_stream() { + self.input_stream_frame(cx, id)?; + } else { + break; + } + } + Ok(()) + } + + fn input_stream_frame( + &mut self, + cx: &mut Context<'_>, + id: u64, + ) -> Result<(), DispatchErrorKind> { + if let Some(header) = self.streams.get_header(id)? { + self.poll_send_header(id, header)?; + } + + // encoding means last frame is still encoding, can not create new frame before + // consumed. + if self.streams.encoding(id)? { + if let Err(e) = self.poll_send_frame(id, None) { + return match e { + DispatchErrorKind::Quic(quiche::Error::StreamStopped(_)) => Ok(()), + e => Err(e), + }; + } + if self.streams.encoding(id)? { + self.streams.push_back_pending_send(id); + return Ok(()); + } + } + + loop { + match self.poll_read_body(cx, id)? { + DataReadState::Closed | DataReadState::Pending => { + break; + } + DataReadState::Ready(data) => { + if let Err(e) = self.poll_send_frame(id, Some(*data)) { + return match e { + DispatchErrorKind::Quic(quiche::Error::StreamStopped(_)) => Ok(()), + e => Err(e), + }; + } + if self.streams.encoding(id)? { + self.streams.push_back_pending_send(id); + break; + } + } + DataReadState::Finish => { + let mut quic_conn = self.quic_conn.lock().unwrap(); + quic_conn.stream_send(id, b"", true)?; + let _ = self.io_manager_tx.send(Ok(())); + break; + } + } + } + Ok(()) + } + + fn poll_send_header(&mut self, id: u64, frame: Frame) -> Result<(), DispatchErrorKind> { + self.streams.set_encoding(id, true)?; + self.encoder.set_frame(id, frame)?; + let quic_conn = self.quic_conn.clone(); + let qpack_encode_stream_id = + self.streams + .qpack_encode_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))?; + let mut quic_conn = quic_conn.lock().unwrap(); + + // invalid means stream has not been created, create it first + if let Err(quiche::Error::InvalidStreamState(_)) = + quic_conn.stream_writable(id, DECODE_BUF_SIZE) + { + quic_conn.stream_send(id, b"", false)?; + } + while quic_conn.stream_writable(id, DECODE_BUF_SIZE)? + && quic_conn.stream_writable(qpack_encode_stream_id, DECODE_BUF_SIZE)? + { + let (data_size, inst_size) = + self.encoder + .encode(id, &mut self.encoder_buf, &mut self.inst_buf)?; + if inst_size != 0 { + quic_conn.stream_send( + qpack_encode_stream_id, + &self.inst_buf[..inst_size], + false, + )?; + } + if data_size != 0 { + quic_conn.stream_send(id, &self.encoder_buf[..data_size], false)?; + } + if inst_size == 0 && data_size == 0 { + self.streams.set_encoding(id, false)?; + break; + } + } + + let _ = self.io_manager_tx.send(Ok(())); + Ok(()) + } + + fn poll_send_frame(&mut self, id: u64, frame: Option) -> Result<(), DispatchErrorKind> { + if let Some(frame) = frame { + self.streams.set_encoding(id, true)?; + self.encoder.set_frame(id, frame)?; + } + let mut quic_conn = self.quic_conn.lock().unwrap(); + + loop { + if !quic_conn.stream_writable(id, DECODE_BUF_SIZE)? { + break; + } + let (data_size, _) = + self.encoder + .encode(id, &mut self.encoder_buf, &mut self.inst_buf)?; + if data_size != 0 { + quic_conn.stream_send(id, &self.encoder_buf[..data_size], false)?; + let _ = self.io_manager_tx.send(Ok(())); + } else { + self.streams.set_encoding(id, false)?; + break; + } + } + Ok(()) + } + + pub(crate) fn poll_read_body( + &mut self, + cx: &mut Context<'_>, + id: u64, + ) -> Result { + const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024; + let len = std::cmp::min( + self.quic_conn + .lock() + .unwrap() + .stream_capacity(id) + .map_err(|_| { + DispatchErrorKind::H3(H3Error::Stream(id, H3ErrorCode::H3InternalError)) + })?, + DEFAULT_MAX_FRAME_SIZE, + ); + let mut buf = [0u8; DEFAULT_MAX_FRAME_SIZE]; + self.streams.poll_sized_data(cx, id, &mut buf[..len]) + } + + pub(crate) fn init(&mut self, config: H3Config) -> Result<(), DispatchErrorKind> { + self.decoder + .local_allowed_max_field_section_size(config.max_field_section_size() as usize); + self.send_settings(config)?; + self.open_uni_stream(QPACK_ENCODER_STREAM_TYPE)?; + self.open_uni_stream(QPACK_DECODER_STREAM_TYPE)?; + Ok(()) + } + + pub(crate) fn open_uni_stream(&mut self, stream_type: u8) -> Result { + let buf = [stream_type]; + let id = match stream_type { + CONTROL_STREAM_TYPE => { + self.streams + .control_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))? + } + QPACK_ENCODER_STREAM_TYPE => { + self.streams + .qpack_encode_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))? + } + QPACK_DECODER_STREAM_TYPE => { + self.streams + .qpack_decode_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))? + } + _ => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))) + } + }; + let mut quic_conn = self.quic_conn.lock().unwrap(); + + quic_conn.stream_send(id, &buf, false)?; + let _ = quic_conn.stream_priority(id, 0, false); + Ok(id) + } + + pub(crate) fn send_settings(&mut self, config: H3Config) -> Result<(), DispatchErrorKind> { + let control_stream_id = self.open_uni_stream(CONTROL_STREAM_TYPE)?; + + let mut settings = Settings::default(); + settings.set_max_field_section_size(config.max_field_section_size()); + settings.set_qpack_max_table_capacity(config.qpack_max_table_capacity()); + settings.set_qpack_block_stream(config.qpack_blocked_streams()); + + let mut quic_conn = self.quic_conn.lock().unwrap(); + let settings = Frame::new(SETTINGS_FRAME_TYPE, Payload::Settings(settings)); + self.encoder.set_frame(control_stream_id, settings)?; + loop { + let (size, _) = self.encoder.encode( + control_stream_id, + &mut self.encoder_buf, + &mut self.inst_buf, + )?; + if size == 0 { + return Ok(()); + } + quic_conn.stream_send(control_stream_id, &self.encoder_buf[..size], false)?; + } + } + + pub(crate) fn poll_stream_recv( + &mut self, + cx: &mut Context<'_>, + ) -> Result<(), DispatchErrorKind> { + let mut need_send = false; + let lock = self.quic_conn.clone(); + let mut quic_conn = lock.lock().unwrap(); + + if let Some(stream_id) = self.streams.peer_control_stream_id() { + need_send |= self.try_recv_uni_stream(cx, &mut quic_conn, stream_id)?; + }; + if let Some(stream_id) = self.streams.peer_qpack_encode_stream_id() { + need_send |= self.try_recv_uni_stream(cx, &mut quic_conn, stream_id)?; + }; + if let Some(stream_id) = self.streams.peer_qpack_decode_stream_id() { + need_send |= self.try_recv_uni_stream(cx, &mut quic_conn, stream_id)?; + }; + for id in quic_conn.readable() { + if !self.streams.frame_acceptable(id) { + continue; + } + need_send |= self.read_stream(cx, &mut quic_conn, id)?; + } + + if quic_conn.is_closed() { + self.shutdown(cx, &DispatchErrorKind::Disconnect); + } + + if need_send { + let _ = self.io_manager_tx.send(Ok(())); + } + Ok(()) + } + + fn try_recv_uni_stream( + &mut self, + cx: &mut Context<'_>, + quic_conn: &mut QuicConn, + stream_id: u64, + ) -> Result { + if quic_conn.stream_finished(stream_id) { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3ClosedCriticalStream, + ))); + } + + match self.read_stream(cx, quic_conn, stream_id) { + Ok(need_send) => { + if quic_conn.stream_finished(stream_id) { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3ClosedCriticalStream, + ))); + } + Ok(need_send) + } + Err(e) => Err(e), + } + } + + fn read_stream( + &mut self, + cx: &mut Context<'_>, + quic_conn: &mut QuicConn, + id: u64, + ) -> Result { + if QUICStreamType::from(id) == QUICStreamType::ServerInitialBidirectional { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3StreamCreationError, + ))); + } + let mut need_send = false; + loop { + let (size, fin) = match quic_conn.stream_recv(id, &mut self.stream_recv_buf) { + Ok((size, fin)) => { + need_send = true; + (size, fin) + } + Err(quiche::Error::Done) => { + return Ok(need_send); + } + Err(quiche::Error::StreamStopped(err)) | Err(quiche::Error::StreamReset(err)) => { + if err != H3ErrorCode::H3NoError as u64 { + return Err(DispatchErrorKind::H3(H3Error::Stream(id, err.into()))); + } else { + return Ok(false); + } + } + Err(e) => { + return Err(DispatchErrorKind::Quic(e)); + } + }; + self.process_recv_data(cx, id, size, quic_conn)?; + if fin { + self.finish_stream(cx, id)?; + return Ok(true); + } + } + } + + fn process_recv_data( + &mut self, + cx: &mut Context<'_>, + id: u64, + size: usize, + quic_conn: &mut QuicConn, + ) -> Result<(), DispatchErrorKind> { + let mut stream_id = id; + let mut size = size; + loop { + match self + .decoder + .decode(stream_id, &self.stream_recv_buf[..size]) + { + Ok(StreamMessage::Request(frames)) => { + self.recv_request_stream(cx, stream_id, frames, quic_conn)?; + } + Ok(StreamMessage::Push(_id, _frames)) => { + // MAX_PUSH_ID not send, Push Stream means error + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3IdError, + ))); + } + Ok(StreamMessage::QpackDecoder(order)) => { + self.recv_qpack_decode_stream(stream_id, order)?; + } + Ok(StreamMessage::Control(frames)) => { + self.recv_control_stream(cx, stream_id, frames)?; + } + Ok(StreamMessage::WaitingMore) | Ok(StreamMessage::Unknown) => {} + Ok(StreamMessage::QpackEncoder(vec)) => { + self.recv_qpack_encode_stream(stream_id, vec)?; + } + Err(e) => { + self.transmit_error(cx, stream_id, DispatchErrorKind::H3(e))?; + } + } + if let Some(id) = self.streams.get_resume_stream_id() { + stream_id = id; + } else { + return Ok(()); + }; + size = 0; + } + } + + fn recv_qpack_encode_stream( + &mut self, + stream_id: u64, + vec: Vec, + ) -> Result<(), DispatchErrorKind> { + self.streams.set_peer_qpack_encode_stream_id(stream_id)?; + for resume_id in vec { + self.streams.resume_stream_recv(resume_id); + } + Ok(()) + } + + fn recv_request_stream( + &mut self, + cx: &mut Context<'_>, + id: u64, + frames: Frames, + quic_conn: &mut QuicConn, + ) -> Result<(), DispatchErrorKind> { + for kind in frames.into_iter() { + let frame = match kind { + FrameKind::Complete(frame) => frame, + FrameKind::Blocked => { + self.streams.pend_stream_recv(id); + return Ok(()); + } + FrameKind::Partial => return Ok(()), + }; + match frame.payload() { + Payload::Headers(headers) => { + self.send_inst_to_peer(headers, quic_conn)?; + self.streams.send_frame(cx, id, *frame)?; + } + Payload::Data(_) => { + self.streams.send_frame(cx, id, *frame)?; + } + Payload::PushPromise(_) => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3IdError, + ))) + } + _ => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))) + } + } + } + Ok(()) + } + + fn recv_control_stream( + &mut self, + cx: &mut Context<'_>, + id: u64, + frames: Frames, + ) -> Result<(), DispatchErrorKind> { + let mut is_first_frame = if let Some(stream_id) = self.streams.peer_control_stream_id() { + if stream_id != id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3StreamCreationError, + ))); + } + false + } else { + self.streams.set_peer_control_stream_id(id)?; + true + }; + for frame in frames.iter() { + let FrameKind::Complete(frame) = frame else { + continue; + }; + match frame.payload() { + Payload::Settings(settings) => { + self.recv_setting_frame(settings)?; + is_first_frame = false; + } + Payload::Goaway(goaway) => { + self.recv_goaway_frame(cx, *goaway.get_id())?; + } + Payload::CancelPush(_cancel) => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3IdError, + ))); + } + _ => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))); + } + } + if is_first_frame { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3MissingSettings, + ))); + } + } + Ok(()) + } + + fn recv_qpack_decode_stream( + &mut self, + stream_id: u64, + order: Vec, + ) -> Result<(), DispatchErrorKind> { + self.streams.set_peer_qpack_decode_stream_id(stream_id)?; + self.encoder.decode_remote_inst(&order)?; + Ok(()) + } + + fn recv_setting_frame(&mut self, settings: &Settings) -> Result<(), DispatchErrorKind> { + if self.peer_settings.is_some() { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))); + } + self.peer_settings = Some(settings.clone()); + if let Some(value) = settings.qpack_max_table_capacity() { + self.encoder.set_max_table_capacity(value as usize)?; + } + if let Some(value) = settings.qpack_block_stream() { + self.encoder.set_max_blocked_stream_size(value as usize); + } + Ok(()) + } + + fn recv_goaway_frame( + &mut self, + cx: &mut Context<'_>, + goaway_id: u64, + ) -> Result<(), DispatchErrorKind> { + self.io_goaway.store(true, Ordering::Relaxed); + self.req_rx.close(); + self.streams.goaway(cx, goaway_id)?; + Ok(()) + } + + fn handle_error(&mut self, cx: &mut Context<'_>, err: &DispatchErrorKind) -> bool { + match err { + DispatchErrorKind::H3(H3Error::Stream(id, e)) => { + self.handle_stream_error(cx, *id, e); + false + } + DispatchErrorKind::Quic(quiche::Error::InvalidStreamState(id)) => { + self.handle_stream_error(cx, *id, &H3ErrorCode::H3NoError); + false + } + err => { + self.handle_connection_error(cx, err); + true + } + } + } + + fn handle_stream_error(&mut self, cx: &mut Context<'_>, id: u64, err: &H3ErrorCode) { + let _ = self + .quic_conn + .lock() + .unwrap() + .stream_shutdown(id, Shutdown::Read, *err as u64); + self.streams.shutdown_stream(cx, id, err); + } + + fn handle_connection_error(&mut self, cx: &mut Context<'_>, err: &DispatchErrorKind) { + self.shutdown(cx, err); + let err = match err { + DispatchErrorKind::H3(H3Error::Connection(err)) => *err, + _ => H3ErrorCode::H3InternalError, + }; + let _ = self.quic_conn.lock().unwrap().close(true, err as u64, b""); + let _ = self.io_manager_tx.send(Ok(())); + self.req_rx.close(); + } + + fn shutdown(&mut self, cx: &mut Context<'_>, err: &DispatchErrorKind) { + self.io_shutdown.store(true, Ordering::Relaxed); + self.streams.shutdown(cx, err); + } + + fn finish_stream(&mut self, cx: &mut Context<'_>, id: u64) -> Result<(), DispatchErrorKind> { + self.streams.finish_stream(cx, id)?; + self.encoder.finish_stream(id)?; + self.decoder.finish_stream(id)?; + if self.streams.goaway_id().is_some() && self.streams.current_concurrency() == 0 { + self.io_shutdown.store(true, Ordering::Relaxed); + } + Ok(()) + } + + pub(crate) fn poll_blocked_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + self.streams.poll_blocked_message(cx) + } +} + +impl Future for StreamManager { + type Output = Result<(), DispatchErrorKind>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + // 1 recv stream_manager_rx, meaning data to send/recv + match this.poll_recv_signal(cx) { + // consume all the signals + Poll::Ready(Ok(Ok(()))) => continue, + Poll::Ready(Ok(Err(e))) | Poll::Ready(Err(e)) => { + if this.handle_error(cx, &e) { + return Poll::Ready(Err(e)); + } + } + Poll::Pending => {} + } + + // 2 check id's channel sendable / control/qpack/push, decode, send or cache + // frame if stream_recv, send io_manager_tx to io manager + if let Err(e) = this.poll_stream_recv(cx) { + if this.handle_error(cx, &e) { + return Poll::Ready(Err(e)); + } + } + + if let Poll::Ready(Err(e)) = this.poll_blocked_message(cx) { + if this.handle_error(cx, &e) { + return Poll::Ready(Err(e)); + } + } + + // 3 recv req_rx, check concurrency + loop { + let req = match this.poll_recv_request(cx) { + Poll::Ready(Ok(req)) => req, + Poll::Ready(Err(e)) => { + if this.handle_error(cx, &e) { + return Poll::Ready(Err(e)); + } + break; + } + Poll::Pending => break, + }; + if let Err(e) = this.streams.new_unidirectional_stream( + req.request.header, + req.request.data, + req.frame_tx.clone(), + ) { + let _ = req.frame_tx.try_send(RespMessage::OutputExit(e)); + } + } + + // 4 in concurrency stream, set frame to encoder, set encoding flag, get encode + // result send flag to io manager + if let Err(e) = this.poll_input_request(cx) { + if this.handle_error(cx, &e) { + return Poll::Ready(Err(e)); + } + } + return Poll::Pending; + } + } +} diff --git a/ylong_http_client/src/util/h3/streams.rs b/ylong_http_client/src/util/h3/streams.rs new file mode 100644 index 0000000..6345d67 --- /dev/null +++ b/ylong_http_client/src/util/h3/streams.rs @@ -0,0 +1,708 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{HashMap, HashSet, VecDeque}; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::task::{Context, Poll}; + +use ylong_http::h3::{Data, Frame, H3Error, H3ErrorCode, Payload, DATA_FRAME_TYPE}; + +use crate::runtime::{BoundedSender, SendError}; +use crate::util::data_ref::BodyDataRef; +use crate::util::dispatcher::http3::{DispatchErrorKind, RespMessage}; + +pub(crate) type OutputSendFut = + Pin>> + Send + Sync>>; + +const HTTP3_FIRST_BIDI_STREAM_ID: u64 = 0u64; +const HTTP3_FIRST_UNI_STREAM_ID: u64 = 2u64; +const HTTP3_MAX_STREAM_ID: u64 = (1u64 << 62) - 1; +const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 100; + +#[derive(PartialEq, Clone)] +pub(crate) enum H3StreamState { + Sending, + HeadersReceived, + BodyReceived, + TrailerReceived, + Shutdown, +} + +#[derive(PartialEq, Clone)] +pub(crate) enum QUICStreamType { + ClientInitialBidirectional, + ServerInitialBidirectional, + ClientInitialUnidirectional, + ServerInitialUnidirectional, +} + +impl QUICStreamType { + pub(crate) fn from(id: u64) -> Self { + match id % 4 { + 0 => QUICStreamType::ClientInitialBidirectional, + 1 => QUICStreamType::ServerInitialBidirectional, + 2 => QUICStreamType::ClientInitialUnidirectional, + _ => QUICStreamType::ServerInitialUnidirectional, + } + } +} + +// Unidirectional Streams +pub(crate) struct BidirectionalStream { + pub(crate) state: H3StreamState, + pub(crate) frame_tx: BoundedSender, + pub(crate) header: Option, + pub(crate) data: BodyDataRef, + pub(crate) pending_message: VecDeque, + pub(crate) encoding: bool, + pub(crate) curr_message: Option, +} + +impl BidirectionalStream { + fn new(frame_tx: BoundedSender, header: Frame, data: BodyDataRef) -> Self { + Self { + state: H3StreamState::Sending, + frame_tx, + header: Some(header), + data, + pending_message: VecDeque::new(), + encoding: false, + curr_message: None, + } + } + + fn transmit_message( + &mut self, + cx: &mut Context<'_>, + message: RespMessage, + ) -> Poll> { + let mut task = { + let sender = self.frame_tx.clone(); + let ft = async move { sender.send(message).await }; + Box::pin(ft) + }; + + match task.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + // The current coroutine sending the request exited prematurely. + Poll::Ready(Err(_)) => Poll::Ready(Err(DispatchErrorKind::ChannelClosed)), + Poll::Pending => { + self.curr_message = Some(task); + Poll::Pending + } + } + } +} + +pub(crate) struct Streams { + bidirectional_stream: HashMap, + control_stream_id: Option, + peer_control_stream_id: Option, + qpack_encode_stream_id: Option, + qpack_decode_stream_id: Option, + peer_qpack_encode_stream_id: Option, + peer_qpack_decode_stream_id: Option, + // unused now + goaway_id: Option, + peer_goaway_id: Option, + // meet the sending conditions, waiting for sending + pending_send: VecDeque, + // cannot recv cause of stream blocks + pending_recv: HashSet, + // stream resumes and should decode again + resume_recv: VecDeque, + // too many working streams, pending for concurrency + pending_concurrency: VecDeque, + // cannot recv cause of channel blocked + pending_channel: HashSet, + working_stream_num: u32, + max_stream_concurrency: u32, + next_uni_stream_id: AtomicU64, + next_bidi_stream_id: AtomicU64, +} + +impl Streams { + pub(crate) fn new() -> Self { + Self { + bidirectional_stream: HashMap::new(), + control_stream_id: None, + peer_control_stream_id: None, + qpack_encode_stream_id: None, + qpack_decode_stream_id: None, + peer_qpack_encode_stream_id: None, + peer_qpack_decode_stream_id: None, + goaway_id: None, + peer_goaway_id: None, + pending_send: VecDeque::new(), + pending_recv: HashSet::new(), + resume_recv: VecDeque::new(), + pending_concurrency: VecDeque::new(), + pending_channel: HashSet::new(), + working_stream_num: 0, + max_stream_concurrency: DEFAULT_MAX_CONCURRENT_STREAMS, + next_uni_stream_id: AtomicU64::new(HTTP3_FIRST_UNI_STREAM_ID), + next_bidi_stream_id: AtomicU64::new(HTTP3_FIRST_BIDI_STREAM_ID), + } + } + + pub(crate) fn new_unidirectional_stream( + &mut self, + header: Frame, + data: BodyDataRef, + rx: BoundedSender, + ) -> Result<(), DispatchErrorKind> { + let id = + self.get_next_bidi_stream_id() + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3GeneralProtocolError, + )))?; + self.bidirectional_stream + .insert(id, BidirectionalStream::new(rx, header, data)); + if self.reach_max_concurrency() { + self.push_back_pending_concurrency(id); + } else { + self.push_back_pending_send(id); + self.increase_current_concurrency(); + } + Ok(()) + } + + pub(crate) fn send_frame( + &mut self, + cx: &mut Context<'_>, + id: u64, + frame: Frame, + ) -> Result<(), DispatchErrorKind> { + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + match stream.state { + H3StreamState::Sending => { + if let Payload::Headers(_) = frame.payload() { + stream.state = H3StreamState::HeadersReceived; + } else { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))); + } + } + H3StreamState::HeadersReceived => { + if let Payload::Headers(_) = frame.payload() { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))); + } else { + stream.state = H3StreamState::BodyReceived; + } + } + H3StreamState::BodyReceived => { + if let Payload::Headers(_) = frame.payload() { + stream.state = H3StreamState::TrailerReceived; + } + } + H3StreamState::TrailerReceived => { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3FrameUnexpected, + ))); + } + H3StreamState::Shutdown => { + // stream has been shutdown, drop frame + return Ok(()); + } + } + if stream.curr_message.is_some() { + stream.pending_message.push_back(RespMessage::Output(frame)); + return Ok(()); + } + if let Poll::Ready(ret) = stream.transmit_message(cx, RespMessage::Output(frame)) { + ret + } else { + self.stream_pend_channel(id); + Ok(()) + } + } else { + Err(DispatchErrorKind::ChannelClosed) + } + } + + pub(crate) fn send_error( + &mut self, + cx: &mut Context<'_>, + id: u64, + error: DispatchErrorKind, + ) -> Result<(), DispatchErrorKind> { + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + stream.pending_message.clear(); + if let Poll::Ready(ret) = stream.transmit_message(cx, RespMessage::OutputExit(error)) { + ret + } else { + self.stream_pend_channel(id); + Ok(()) + } + } else { + Err(DispatchErrorKind::ChannelClosed) + } + } + + pub(crate) fn control_stream_id(&mut self) -> Option { + if self.control_stream_id.is_some() { + self.control_stream_id + } else { + self.control_stream_id = self.get_next_uni_stream_id(); + self.control_stream_id + } + } + + pub(crate) fn qpack_decode_stream_id(&mut self) -> Option { + if self.qpack_decode_stream_id.is_some() { + self.qpack_decode_stream_id + } else { + self.qpack_decode_stream_id = self.get_next_uni_stream_id(); + self.qpack_decode_stream_id + } + } + + pub(crate) fn qpack_encode_stream_id(&mut self) -> Option { + if self.qpack_encode_stream_id.is_some() { + self.qpack_encode_stream_id + } else { + self.qpack_encode_stream_id = self.get_next_uni_stream_id(); + self.qpack_encode_stream_id + } + } + + pub(crate) fn peer_qpack_encode_stream_id(&self) -> Option { + self.peer_qpack_encode_stream_id + } + + pub(crate) fn peer_goaway_id(&self) -> Option { + self.peer_goaway_id + } + + #[allow(unused)] + pub(crate) fn goaway_id(&self) -> Option { + self.goaway_id + } + + pub(crate) fn peer_control_stream_id(&self) -> Option { + self.peer_control_stream_id + } + + pub(crate) fn peer_qpack_decode_stream_id(&self) -> Option { + self.peer_qpack_decode_stream_id + } + + pub(crate) fn set_peer_qpack_encode_stream_id( + &mut self, + id: u64, + ) -> Result<(), DispatchErrorKind> { + if let Some(old_id) = self.peer_qpack_encode_stream_id { + if old_id != id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3StreamCreationError, + ))); + } + } else { + self.peer_qpack_encode_stream_id = Some(id); + } + Ok(()) + } + + pub(crate) fn set_peer_control_stream_id(&mut self, id: u64) -> Result<(), DispatchErrorKind> { + if let Some(old_id) = self.peer_control_stream_id { + if old_id != id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3StreamCreationError, + ))); + } + } else { + self.peer_control_stream_id = Some(id); + } + Ok(()) + } + + pub(crate) fn set_peer_qpack_decode_stream_id( + &mut self, + id: u64, + ) -> Result<(), DispatchErrorKind> { + if let Some(old_id) = self.peer_qpack_decode_stream_id { + if old_id != id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3StreamCreationError, + ))); + } + } else { + self.peer_qpack_decode_stream_id = Some(id); + } + Ok(()) + } + + #[allow(unused)] + pub(crate) fn set_goaway_id(&mut self, id: u64) -> Result<(), DispatchErrorKind> { + if let Some(old_goaway_id) = self.goaway_id { + if id > old_goaway_id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))); + } + } + self.goaway_id = Some(id); + Ok(()) + } + + pub(crate) fn get_header(&mut self, id: u64) -> Result, DispatchErrorKind> { + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + Ok(stream.header.take()) + } else { + Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))) + } + } + + pub(crate) fn frame_acceptable(&mut self, id: u64) -> bool { + !self.is_stream_recv_pending(id) && !self.is_stream_channel_pending(id) + } + + pub(crate) fn decrease_current_concurrency(&mut self) { + self.working_stream_num -= 1; + } + + pub(crate) fn increase_current_concurrency(&mut self) { + self.working_stream_num += 1; + } + + pub(crate) fn current_concurrency(&mut self) -> u32 { + self.working_stream_num + } + + pub(crate) fn reach_max_concurrency(&mut self) -> bool { + self.working_stream_num >= self.max_stream_concurrency + } + + pub(crate) fn push_back_pending_send(&mut self, id: u64) { + self.pending_send.push_back(id); + } + + pub(crate) fn next_stream(&mut self) -> Option { + self.pending_send.pop_front() + } + + pub(crate) fn pending_stream_len(&mut self) -> u64 { + self.pending_send.len() as u64 + } + + pub(crate) fn push_back_pending_concurrency(&mut self, id: u64) { + self.pending_concurrency.push_back(id); + } + + pub(crate) fn pop_front_pending_concurrency(&mut self) -> Option { + self.pending_concurrency.pop_front() + } + + pub(crate) fn stream_pend_channel(&mut self, id: u64) { + self.pending_channel.insert(id); + } + + pub(crate) fn is_stream_channel_pending(&self, id: u64) -> bool { + self.pending_channel.contains(&id) + } + + pub(crate) fn try_consume_pending_concurrency(&mut self) { + while !self.reach_max_concurrency() { + match self.pop_front_pending_concurrency() { + Some(id) => { + self.push_back_pending_send(id); + self.increase_current_concurrency(); + } + None => { + return; + } + } + } + } + + pub(crate) fn get_next_uni_stream_id(&self) -> Option { + let id = self.next_uni_stream_id.fetch_add(4, Ordering::Relaxed); + if id > HTTP3_MAX_STREAM_ID { + None + } else { + Some(id) + } + } + + pub(crate) fn get_next_bidi_stream_id(&self) -> Option { + let id = self.next_bidi_stream_id.fetch_add(4, Ordering::Relaxed); + if id > HTTP3_MAX_STREAM_ID { + None + } else { + Some(id) + } + } + + pub(crate) fn pend_stream_recv(&mut self, id: u64) { + self.pending_recv.insert(id); + } + + pub(crate) fn resume_stream_recv(&mut self, id: u64) { + self.pending_recv.remove(&id); + self.resume_recv.push_back(id); + } + + pub(crate) fn is_stream_recv_pending(&self, id: u64) -> bool { + self.pending_recv.contains(&id) + } + + pub(crate) fn get_resume_stream_id(&mut self) -> Option { + self.resume_recv.pop_front() + } + + pub(crate) fn poll_sized_data( + &mut self, + cx: &mut Context<'_>, + id: u64, + buf: &mut [u8], + ) -> Result { + let stream = self + .bidirectional_stream + .get_mut(&id) + .ok_or(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))?; + + if stream.state == H3StreamState::Shutdown { + return Ok(DataReadState::Closed); + } + + match stream.data.poll_read(cx, buf) { + Poll::Ready(Some(size)) => { + if size > 0 { + let data_vec = Vec::from(&buf[..size]); + Ok(DataReadState::Ready(Box::new(Frame::new( + DATA_FRAME_TYPE, + Payload::Data(Data::new(data_vec)), + )))) + } else { + Ok(DataReadState::Finish) + } + } + Poll::Ready(None) => Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))), + Poll::Pending => { + self.push_back_pending_send(id); + Ok(DataReadState::Pending) + } + } + } + + pub(crate) fn shutdown_stream(&mut self, cx: &mut Context<'_>, id: u64, err: &H3ErrorCode) { + let Some(stream) = self.bidirectional_stream.get_mut(&id) else { + return; + }; + if stream + .transmit_message( + cx, + RespMessage::OutputExit(DispatchErrorKind::H3(H3Error::Stream(id, *err))), + ) + .is_pending() + { + self.stream_pend_channel(id); + } + self.decrease_current_concurrency(); + // stream.header = None; + // stream.pending_frame.clear(); + // stream.data.clear(); + // stream.state = H3StreamState::Shutdown; + } + + pub(crate) fn goaway( + &mut self, + cx: &mut Context<'_>, + goaway_id: u64, + ) -> Result<(), DispatchErrorKind> { + if let Some(old_goaway_id) = self.peer_goaway_id() { + if goaway_id > old_goaway_id { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3IdError, + ))); + } + } + if QUICStreamType::from(goaway_id) != QUICStreamType::ClientInitialBidirectional { + return Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3IdError, + ))); + } + self.goaway_id = Some(goaway_id); + let mut pending_channels = Vec::new(); + for (id, stream) in self.bidirectional_stream.iter_mut() { + if id > &goaway_id { + stream.state = H3StreamState::Shutdown; + stream.header = None; + stream.pending_message.clear(); + stream.data.clear(); + if stream + .transmit_message( + cx, + RespMessage::OutputExit(DispatchErrorKind::GoawayReceived), + ) + .is_pending() + { + pending_channels.push(*id); + } + } + } + for id in pending_channels { + self.stream_pend_channel(id); + } + Ok(()) + } + + pub(crate) fn shutdown(&mut self, cx: &mut Context<'_>, err: &DispatchErrorKind) { + let mut pending_channels = Vec::new(); + for (id, stream) in self.bidirectional_stream.iter_mut() { + stream.state = H3StreamState::Shutdown; + stream.header = None; + stream.pending_message.clear(); + stream.data.clear(); + if stream + .transmit_message(cx, RespMessage::OutputExit(err.clone())) + .is_pending() + { + pending_channels.push(*id); + } + } + for id in pending_channels { + self.stream_pend_channel(id); + } + } + + pub(crate) fn set_encoding( + &mut self, + id: u64, + encoding: bool, + ) -> Result<(), DispatchErrorKind> { + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + stream.encoding = encoding; + Ok(()) + } else { + Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))) + } + } + + pub(crate) fn encoding(&mut self, id: u64) -> Result { + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + Ok(stream.encoding) + } else { + Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))) + } + } + + pub(crate) fn finish_stream( + &mut self, + cx: &mut Context<'_>, + id: u64, + ) -> Result<(), DispatchErrorKind> { + if QUICStreamType::from(id) != QUICStreamType::ClientInitialBidirectional { + return if Some(id) == self.peer_control_stream_id() + || Some(id) == self.peer_qpack_encode_stream_id() + || Some(id) == self.peer_qpack_decode_stream_id() + { + Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3ClosedCriticalStream, + ))) + } else { + Ok(()) + }; + } + self.decrease_current_concurrency(); + if let Some(stream) = self.bidirectional_stream.get_mut(&id) { + stream.state = H3StreamState::Shutdown; + if stream.curr_message.is_none() { + if let Poll::Ready(ret) = stream.transmit_message( + cx, + RespMessage::OutputExit(DispatchErrorKind::StreamFinished), + ) { + ret + } else { + self.stream_pend_channel(id); + Ok(()) + } + } else { + stream + .pending_message + .push_back(RespMessage::OutputExit(DispatchErrorKind::StreamFinished)); + Ok(()) + } + } else { + Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + ))) + } + } + + pub(crate) fn poll_blocked_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + let mut new_set = HashSet::new(); + for id in &self.pending_channel { + let Some(stream) = self.bidirectional_stream.get_mut(id) else { + return Poll::Ready(Err(DispatchErrorKind::H3(H3Error::Connection( + H3ErrorCode::H3InternalError, + )))); + }; + if let Some(mut task) = stream.curr_message.take() { + match task.as_mut().poll(cx) { + Poll::Ready(Ok(_)) => {} + Poll::Ready(Err(_)) => { + // todo: shutdown + stream.state = H3StreamState::Shutdown; + } + Poll::Pending => { + stream.curr_message = Some(task); + new_set.insert(*id); + continue; + } + } + } + while let Some(message) = stream.pending_message.pop_front() { + match stream.transmit_message(cx, message) { + Poll::Ready(Ok(())) => {} + Poll::Pending => { + new_set.insert(*id); + break; + } + Poll::Ready(Err(_)) => { + stream.state = H3StreamState::Shutdown; + break; + } + } + } + } + self.pending_channel = new_set; + Poll::Pending + } +} + +pub(crate) enum DataReadState { + Closed, + // Wait for poll_read or wait for window. + Pending, + Ready(Box), + Finish, +} diff --git a/ylong_http_client/src/util/mod.rs b/ylong_http_client/src/util/mod.rs index 58a5562..f4a978c 100644 --- a/ylong_http_client/src/util/mod.rs +++ b/ylong_http_client/src/util/mod.rs @@ -29,23 +29,29 @@ pub(crate) mod redirect; #[cfg(feature = "async")] pub(crate) mod request; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub(crate) mod c_openssl; #[cfg(any(feature = "http1_1", feature = "http2"))] pub(crate) mod dispatcher; +#[cfg(feature = "http3")] +pub(crate) mod alt_svc; +#[cfg(any(feature = "http3", feature = "http2"))] +pub(crate) mod data_ref; #[cfg(feature = "http2")] pub(crate) mod h2; +#[cfg(feature = "http3")] +pub(crate) mod h3; #[cfg(all(test, feature = "ylong_base"))] pub(crate) mod test_utils; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub use c_openssl::{ Cert, Certificate, PubKeyPins, PubKeyPinsBuilder, TlsConfig, TlsConfigBuilder, TlsFileType, TlsVersion, }; -#[cfg(feature = "__c_openssl")] +#[cfg(feature = "__tls")] pub(crate) use config::{AlpnProtocol, AlpnProtocolList}; #[cfg(feature = "__tls")] pub use config::{CertVerifier, ServerCerts}; diff --git a/ylong_http_client/src/util/redirect.rs b/ylong_http_client/src/util/redirect.rs index 258ca37..549d7ce 100644 --- a/ylong_http_client/src/util/redirect.rs +++ b/ylong_http_client/src/util/redirect.rs @@ -37,6 +37,7 @@ impl Redirect { } } + // todo: check h3? pub(crate) fn redirect( &self, request: &mut Request, diff --git a/ylong_http_client/tests/sdv_async_client_build.rs b/ylong_http_client/tests/sdv_async_client_build.rs index 848dd01..63720b2 100644 --- a/ylong_http_client/tests/sdv_async_client_build.rs +++ b/ylong_http_client/tests/sdv_async_client_build.rs @@ -37,7 +37,6 @@ fn sdv_client_tls_builder() { .add_root_certificate(Certificate::from_pem(b"cert").unwrap()) .tls_ca_file("ca.crt") .tls_cipher_list("DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK") - .tls_cipher_suites("DEFAULT:!aNULL:!eNULL:!MD5:!3DES:!DES:!RC4:!IDEA:!SEED:!aDSS:!SRP:!PSK") .tls_built_in_root_certs(false) .danger_accept_invalid_certs(false) .danger_accept_invalid_hostnames(false) diff --git a/ylong_http_client/tests/sdv_async_https_pinning.rs b/ylong_http_client/tests/sdv_async_https_pinning.rs index 794b6ed..05e85de 100644 --- a/ylong_http_client/tests/sdv_async_https_pinning.rs +++ b/ylong_http_client/tests/sdv_async_https_pinning.rs @@ -14,7 +14,7 @@ #![cfg(all( feature = "async", feature = "http1_1", - feature = "__c_openssl", + feature = "__tls", feature = "tokio_base" ))] -- Gitee From 3e661801976f420ec9d82a376523c0ef600884e8 Mon Sep 17 00:00:00 2001 From: huaxin Date: Thu, 19 Sep 2024 10:35:09 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E4=BF=AE=E6=94=B9content-length=E7=9A=84?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E7=B1=BB=E5=9E=8B=E4=B8=BAu64=EF=BC=8C?= =?UTF-8?q?=E4=BF=9D=E8=AF=81=E5=A4=A7=E6=96=87=E4=BB=B6=E5=8F=AF=E4=BB=A5?= =?UTF-8?q?=E6=AD=A3=E5=B8=B8=E4=B8=8B=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: huaxin Change-Id: I8db61e7cffbbfbf66f1389c6fd3fe791dbf136fe --- ylong_http/src/body/text.rs | 11 ++++++----- ylong_http_client/src/async_impl/http_body.rs | 2 +- ylong_http_client/src/sync_impl/conn/http1.rs | 2 +- ylong_http_client/src/sync_impl/http_body.rs | 4 ++-- ylong_http_client/src/util/normalizer.rs | 4 ++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ylong_http/src/body/text.rs b/ylong_http/src/body/text.rs index 5d96dbe..2c34329 100644 --- a/ylong_http/src/body/text.rs +++ b/ylong_http/src/body/text.rs @@ -276,7 +276,7 @@ impl async_impl::Body for TextBody TextBodyDecoder { + pub fn new(length: u64) -> TextBodyDecoder { TextBodyDecoder { left: length } } @@ -348,12 +348,13 @@ impl TextBodyDecoder { return (Text::complete(&buf[..0]), buf); } - let size = min(self.left, buf.len()); + let size = min(self.left, buf.len() as u64); self.left -= size; + let end = size as usize; if self.left == 0 { - (Text::complete(&buf[..size]), &buf[size..]) + (Text::complete(&buf[..end]), &buf[end..]) } else { - (Text::partial(&buf[..size]), &buf[size..]) + (Text::partial(&buf[..end]), &buf[end..]) } } } diff --git a/ylong_http_client/src/async_impl/http_body.rs b/ylong_http_client/src/async_impl/http_body.rs index a8fa85a..f504a6b 100644 --- a/ylong_http_client/src/async_impl/http_body.rs +++ b/ylong_http_client/src/async_impl/http_body.rs @@ -275,7 +275,7 @@ struct Text { impl Text { pub(crate) fn new( - len: usize, + len: u64, pre: &[u8], io: BoxStreamData, interceptors: Arc, diff --git a/ylong_http_client/src/sync_impl/conn/http1.rs b/ylong_http_client/src/sync_impl/conn/http1.rs index 43a5437..f92ae41 100644 --- a/ylong_http_client/src/sync_impl/conn/http1.rs +++ b/ylong_http_client/src/sync_impl/conn/http1.rs @@ -106,7 +106,7 @@ where .headers .get("Content-Length") .map(|v| v.to_string().unwrap_or(String::new())) - .and_then(|s| s.parse::().ok()); + .and_then(|s| s.parse::().ok()); let is_trailer = part.headers.get("Trailer").is_some(); diff --git a/ylong_http_client/src/sync_impl/http_body.rs b/ylong_http_client/src/sync_impl/http_body.rs index 4bafd28..1574b8d 100644 --- a/ylong_http_client/src/sync_impl/http_body.rs +++ b/ylong_http_client/src/sync_impl/http_body.rs @@ -56,7 +56,7 @@ impl HttpBody { Self { kind: Kind::Empty } } - pub(crate) fn text(len: usize, pre: &[u8], io: BoxStreamData) -> Self { + pub(crate) fn text(len: u64, pre: &[u8], io: BoxStreamData) -> Self { Self { kind: Kind::Text(Text::new(len, pre, io)), } @@ -83,7 +83,7 @@ struct Text { } impl Text { - pub(crate) fn new(len: usize, pre: &[u8], io: BoxStreamData) -> Self { + pub(crate) fn new(len: u64, pre: &[u8], io: BoxStreamData) -> Self { Self { decoder: TextBodyDecoder::new(len), pre: (!pre.is_empty()).then_some(Cursor::new(pre.to_vec())), diff --git a/ylong_http_client/src/util/normalizer.rs b/ylong_http_client/src/util/normalizer.rs index d2d4812..00b71e8 100644 --- a/ylong_http_client/src/util/normalizer.rs +++ b/ylong_http_client/src/util/normalizer.rs @@ -162,7 +162,7 @@ impl<'a> BodyLengthParser<'a> { if content_length.is_some() { let content_length_valid = content_length .and_then(|v| v.to_string().ok()) - .and_then(|s| s.parse::().ok()); + .and_then(|s| s.parse::().ok()); return match content_length_valid { // If `content-length` is 0, the io stream cannot be read, @@ -180,7 +180,7 @@ impl<'a> BodyLengthParser<'a> { pub(crate) enum BodyLength { #[cfg(feature = "http1_1")] Chunk, - Length(usize), + Length(u64), Empty, UntilClose, } -- Gitee