From 1f4640d4cecea72c8687ff33df2e544c9c77c24a Mon Sep 17 00:00:00 2001 From: Tiga Ultraman Date: Thu, 8 May 2025 14:35:56 +0800 Subject: [PATCH] add rate limit Signed-off-by: Tiga Ultraman --- ylong_http_client/src/async_impl/client.rs | 49 ++ .../src/async_impl/conn/http1.rs | 129 +++-- .../src/async_impl/conn/http2.rs | 67 ++- .../src/async_impl/conn/http3.rs | 64 ++- ylong_http_client/src/async_impl/http_body.rs | 13 +- ylong_http_client/src/async_impl/pool.rs | 56 ++- ylong_http_client/src/lib.rs | 2 + .../src/util/c_openssl/ffi/stack.rs | 12 +- .../src/util/c_openssl/ssl/stream.rs | 6 +- ylong_http_client/src/util/c_openssl/stack.rs | 1 - .../src/util/c_openssl/verify/pinning.rs | 6 +- ylong_http_client/src/util/config/http.rs | 6 +- ylong_http_client/src/util/data_ref.rs | 29 +- ylong_http_client/src/util/dispatcher.rs | 21 +- ylong_http_client/src/util/h2/manager.rs | 44 +- ylong_http_client/src/util/h2/streams.rs | 15 +- ylong_http_client/src/util/h3/streams.rs | 4 +- ylong_http_client/src/util/mod.rs | 1 + ylong_http_client/src/util/pool.rs | 19 +- ylong_http_client/src/util/progress/mod.rs | 16 + .../src/util/progress/rate_limit.rs | 441 ++++++++++++++++++ .../tests/sdv_async_https_pinning.rs | 3 +- 22 files changed, 889 insertions(+), 115 deletions(-) create mode 100644 ylong_http_client/src/util/progress/mod.rs create mode 100644 ylong_http_client/src/util/progress/rate_limit.rs diff --git a/ylong_http_client/src/async_impl/client.rs b/ylong_http_client/src/async_impl/client.rs index a451c75..12ba380 100644 --- a/ylong_http_client/src/async_impl/client.rs +++ b/ylong_http_client/src/async_impl/client.rs @@ -461,6 +461,55 @@ impl ClientBuilder { self } + /// Sets the maximum number of bytes per second allowed for data transfer. + /// + /// By default, there is no limit. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().max_speed_limit(5); + /// ``` + pub fn max_speed_limit(mut self, rate: u64) -> Self { + self.http.speed_config.set_max_rate(rate); + self + } + + /// Sets the minimum number of bytes per second allowed for data transfer. + /// + /// By default, there is no limit. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().min_speed_limit(5); + /// ``` + pub fn min_speed_limit(mut self, rate: u64) -> Self { + self.http.speed_config.set_min_rate(rate); + self + } + + /// Sets the maximum time that the speed is allowed to be below + /// min_speed_limit, beyond which the abort is executed. + /// + /// By default, there is no limit. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::ClientBuilder; + /// + /// let builder = ClientBuilder::new().min_speed_limit(5); + /// ``` + pub fn min_speed_interval(mut self, seconds: u64) -> Self { + self.http.speed_config.set_min_speed_interval(seconds); + self + } + /// Adds a `Interceptor` to the `Client`. /// /// # Examples diff --git a/ylong_http_client/src/async_impl/conn/http1.rs b/ylong_http_client/src/async_impl/conn/http1.rs index c949a68..8fb4470 100644 --- a/ylong_http_client/src/async_impl/conn/http1.rs +++ b/ylong_http_client/src/async_impl/conn/http1.rs @@ -11,6 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; use std::mem::take; use std::pin::Pin; use std::sync::Arc; @@ -28,12 +29,13 @@ use super::StreamData; use crate::async_impl::request::Message; use crate::async_impl::{HttpBody, Request, Response}; use crate::error::HttpClientError; -use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use crate::runtime::{poll_fn, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use crate::util::config::HttpVersion; use crate::util::dispatcher::http1::Http1Conn; use crate::util::information::ConnInfo; use crate::util::interceptor::Interceptors; use crate::util::normalizer::BodyLengthParser; +use crate::ErrorKind::BodyTransfer; const TEMP_BUF_SIZE: usize = 16 * 1024; @@ -66,33 +68,28 @@ where let (part, pre) = { let mut decoder = ResponseDecoder::new(); loop { - let size = match conn.raw_mut().read(buf.as_mut_slice()).await { - Ok(0) => { - conn.shutdown(); - return err_from_msg!(Request, "Tcp closed"); - } - Ok(size) => { - if message - .request - .ref_mut() - .time_group_mut() - .transfer_end_time() - .is_none() - { - message - .request - .ref_mut() - .time_group_mut() - .set_transfer_end(Instant::now()) - } - size + let size = poll_fn(|cx| { + if conn.speed_controller.poll_recv_pending_timeout(cx) { + return Poll::Ready(Err(HttpClientError::from_str( + BodyTransfer, + "Below low speed limit", + ))); } - Err(e) => { - conn.shutdown(); - return err_from_io!(Request, e); + let result = { + let mut read_fut = Box::pin(read_status_line( + &mut conn, + message.request.ref_mut(), + buf.as_mut_slice(), + )); + read_fut.as_mut().poll(cx)? + }; + if let Poll::Ready(filled) = result { + conn.speed_controller.reset_recv_pending_timeout(); + return Poll::Ready(Ok(filled)); } - }; - + Poll::Pending + }) + .await?; message.interceptor.intercept_output(&buf[..size])?; match decoder.decode(&buf[..size]) { Ok(None) => {} @@ -108,6 +105,32 @@ where decode_response(message, part, conn, pre) } +async fn read_status_line( + conn: &mut Http1Conn, + request: &mut Request, + buf: &mut [u8], +) -> Result +where + S: AsyncRead + Sync + Send + Unpin + 'static, +{ + match conn.raw_mut().read(buf).await { + Ok(0) => { + conn.shutdown(); + err_from_msg!(Request, "Tcp closed") + } + Ok(size) => { + if request.time_group_mut().transfer_end_time().is_none() { + request.time_group_mut().set_transfer_end(Instant::now()) + } + Ok(size) + } + Err(e) => { + conn.shutdown(); + err_from_io!(Request, e) + } + } +} + async fn encode_various_body( request: &mut Request, conn: &mut Http1Conn, @@ -260,10 +283,31 @@ where end_body = end; } if written == buf.len() || end_body { - if let Err(e) = conn.raw_mut().write_all(&buf[..written]).await { + conn.speed_controller.init_min_send_if_not_start(); + conn.speed_controller.init_max_send_if_not_start(); + let write_res = poll_fn(|cx| { + if conn.speed_controller.poll_send_pending_timeout(cx) { + return Poll::Ready(Err(HttpClientError::from_str( + BodyTransfer, + "Below low speed limit", + ))); + } + let mut write_fut = Box::pin(conn.raw_mut().write_all(&buf[..written])); + let write_poll = write_fut.as_mut().poll(cx); + if let Poll::Ready(Ok(_)) = write_poll { + conn.speed_controller.reset_send_pending_timeout(); + } + write_poll.map_err(|e| HttpClientError::from_error(BodyTransfer, e)) + }) + .await; + if let Err(e) = write_res { conn.shutdown(); - return err_from_io!(BodyTransfer, e); + return Err(e); } + if conn.speed_controller.need_limit_max_send_speed() { + conn.speed_controller.max_send_speed_limit(written).await; + } + conn.speed_controller.min_send_speed_limit(written)?; written = 0; } } @@ -303,7 +347,34 @@ impl AsyncRead for Http1Conn { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(self.raw_mut()).poll_read(cx, buf) + if self.speed_controller.poll_recv_pending_timeout(cx) { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + HttpClientError::from_str(BodyTransfer, "Below low speed limit"), + ))); + } + self.speed_controller.init_min_recv_if_not_start(); + if self + .speed_controller + .poll_max_recv_delay_time(cx) + .is_pending() + { + return Poll::Pending; + } + self.speed_controller.init_max_recv_if_not_start(); + match Pin::new(self.raw_mut()).poll_read(cx, buf) { + Poll::Ready(Ok(_)) => { + let filled: usize = buf.filled().len(); + self.speed_controller + .min_recv_speed_limit(filled) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + self.speed_controller.delay_max_recv_speed_limit(filled); + self.speed_controller.reset_recv_pending_timeout(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } } } diff --git a/ylong_http_client/src/async_impl/conn/http2.rs b/ylong_http_client/src/async_impl/conn/http2.rs index 9fd6a1a..e9ecb28 100644 --- a/ylong_http_client/src/async_impl/conn/http2.rs +++ b/ylong_http_client/src/async_impl/conn/http2.rs @@ -38,6 +38,7 @@ use crate::util::data_ref::BodyDataRef; use crate::util::dispatcher::http2::Http2Conn; use crate::util::h2::RequestWrapper; use crate::util::normalizer::BodyLengthParser; +use crate::ErrorKind::BodyTransfer; const UNUSED_FLAG: u8 = 0x0; @@ -57,7 +58,7 @@ where let is_end_stream = message.request.ref_mut().body().is_empty(); let (flag, payload) = build_headers_payload(part, is_end_stream) .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?; - let data = BodyDataRef::new(message.request.clone()); + let data = BodyDataRef::new(message.request.clone(), conn.speed_controller.clone()); let stream = RequestWrapper { flag, payload, @@ -368,25 +369,79 @@ impl AsyncRead for TextIo { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let text_io = self.get_mut(); let mut buf = HttpReadBuf { buf }; - + let text_io = self.get_mut(); if buf.remaining() == 0 || text_io.is_closed { return Poll::Ready(Ok(())); } + if text_io + .handle + .speed_controller + .poll_recv_pending_timeout(cx) + { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + HttpClientError::from_str(BodyTransfer, "Below low speed limit"), + ))); + } + // Min speed contains the max speed limit sleep time. + text_io.handle.speed_controller.init_min_recv_if_not_start(); + if text_io + .handle + .speed_controller + .poll_max_recv_delay_time(cx) + .is_pending() + { + return Poll::Pending; + } + text_io.handle.speed_controller.init_max_recv_if_not_start(); while buf.remaining() != 0 { if let Some(result) = Self::read_remaining_data(text_io, &mut buf) { - return result; + return match result { + Poll::Ready(Ok(_)) => { + let filled: usize = buf.filled().len(); + text_io + .handle + .speed_controller + .min_recv_speed_limit(filled) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + text_io + .handle + .speed_controller + .delay_max_recv_speed_limit(filled); + text_io.handle.speed_controller.reset_recv_pending_timeout(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; } let poll_result = text_io .handle .receiver .poll_recv(cx) - .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?; + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) { - return result; + return match result { + Poll::Ready(Ok(_)) => { + let filled: usize = buf.filled().len(); + text_io + .handle + .speed_controller + .min_recv_speed_limit(filled) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + text_io + .handle + .speed_controller + .delay_max_recv_speed_limit(filled); + text_io.handle.speed_controller.reset_recv_pending_timeout(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; } } Poll::Ready(Ok(())) diff --git a/ylong_http_client/src/async_impl/conn/http3.rs b/ylong_http_client/src/async_impl/conn/http3.rs index f1d712b..0d952d0 100644 --- a/ylong_http_client/src/async_impl/conn/http3.rs +++ b/ylong_http_client/src/async_impl/conn/http3.rs @@ -35,6 +35,7 @@ use crate::util::config::HttpVersion; use crate::util::data_ref::BodyDataRef; use crate::util::dispatcher::http3::{DispatchErrorKind, Http3Conn, RequestWrapper, RespMessage}; use crate::util::normalizer::BodyLengthParser; +use crate::ErrorKind::BodyTransfer; use crate::{ErrorKind, HttpClientError}; pub(crate) async fn request( @@ -52,7 +53,7 @@ where // TODO Implement trailer. let headers = build_headers_frame(part) .map_err(|e| HttpClientError::from_error(ErrorKind::Request, e))?; - let data = BodyDataRef::new(message.request.clone()); + let data = BodyDataRef::new(message.request.clone(), conn.speed_controller.clone()); let stream = RequestWrapper { header: headers, data, @@ -303,15 +304,51 @@ impl AsyncRead for TextIo { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - let text_io = self.get_mut(); let mut buf = HttpReadBuf { buf }; - + let text_io = self.get_mut(); if buf.remaining() == 0 || text_io.is_closed { return Poll::Ready(Ok(())); } + if text_io + .handle + .speed_controller + .poll_recv_pending_timeout(cx) + { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + HttpClientError::from_str(BodyTransfer, "Below low speed limit"), + ))); + } + text_io.handle.speed_controller.init_min_recv_if_not_start(); + if text_io + .handle + .speed_controller + .poll_max_recv_delay_time(cx) + .is_pending() + { + return Poll::Pending; + } + text_io.handle.speed_controller.init_max_recv_if_not_start(); while buf.remaining() != 0 { if let Some(result) = Self::read_remaining_data(text_io, &mut buf) { - return result; + return match result { + Poll::Ready(Ok(_)) => { + let filled: usize = buf.filled().len(); + text_io + .handle + .speed_controller + .min_recv_speed_limit(filled) + .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?; + text_io + .handle + .speed_controller + .delay_max_recv_speed_limit(filled); + text_io.handle.speed_controller.reset_recv_pending_timeout(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; } let poll_result = text_io @@ -321,7 +358,24 @@ impl AsyncRead for TextIo { .map_err(|_e| std::io::Error::from(std::io::ErrorKind::ConnectionAborted))?; if let Some(result) = Self::match_channel_message(poll_result, text_io, &mut buf) { - return result; + return match result { + Poll::Ready(Ok(_)) => { + let filled: usize = buf.filled().len(); + text_io + .handle + .speed_controller + .min_recv_speed_limit(filled) + .map_err(|_e| std::io::Error::from(std::io::ErrorKind::Other))?; + text_io + .handle + .speed_controller + .delay_max_recv_speed_limit(filled); + text_io.handle.speed_controller.reset_recv_pending_timeout(); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + }; } } Poll::Ready(Ok(())) diff --git a/ylong_http_client/src/async_impl/http_body.rs b/ylong_http_client/src/async_impl/http_body.rs index 414e397..bef1026 100644 --- a/ylong_http_client/src/async_impl/http_body.rs +++ b/ylong_http_client/src/async_impl/http_body.rs @@ -65,7 +65,7 @@ const TRAILER_SIZE: usize = 1024; /// ``` pub struct HttpBody { kind: Kind, - sleep: Option>>, + timeout: Option>>, } type BoxStreamData = Box; @@ -92,11 +92,14 @@ impl HttpBody { #[cfg(feature = "http1_1")] BodyLength::Chunk => Kind::Chunk(Chunk::new(pre, io, interceptors)), }; - Ok(Self { kind, sleep: None }) + Ok(Self { + kind, + timeout: None, + }) } pub(crate) fn set_sleep(&mut self, sleep: Option>>) { - self.sleep = sleep; + self.timeout = sleep; } } @@ -112,7 +115,7 @@ impl Body for HttpBody { return Poll::Ready(Ok(0)); } - if let Some(delay) = self.sleep.as_mut() { + if let Some(delay) = self.timeout.as_mut() { if let Poll::Ready(()) = Pin::new(delay).poll(cx) { return Poll::Ready(err_from_io!(Timeout, std::io::ErrorKind::TimedOut.into())); } @@ -132,7 +135,7 @@ impl Body for HttpBody { cx: &mut Context<'_>, ) -> Poll, Self::Error>> { // Get trailer data from io - if let Some(delay) = self.sleep.as_mut() { + if let Some(delay) = self.timeout.as_mut() { if let Poll::Ready(()) = Pin::new(delay).poll(cx) { return Poll::Ready(err_from_msg!(Timeout, "Request timeout")); } diff --git a/ylong_http_client/src/async_impl/pool.rs b/ylong_http_client/src/async_impl/pool.rs index 1106c89..a8a5a7f 100644 --- a/ylong_http_client/src/async_impl/pool.rs +++ b/ylong_http_client/src/async_impl/pool.rs @@ -38,6 +38,7 @@ use crate::util::config::{HttpConfig, HttpVersion}; use crate::util::dispatcher::http1::{WrappedSemPermit, WrappedSemaphore}; use crate::util::dispatcher::{Conn, ConnDispatcher, Dispatcher, TimeInfoConn}; use crate::util::pool::{Pool, PoolKey}; +use crate::util::progress::SpeedConfig; #[cfg(feature = "http3")] use crate::util::request::RequestArc; use crate::util::ConnInfo; @@ -76,7 +77,12 @@ impl ConnPool { #[cfg(feature = "http3")] let alt_svc = self.alt_svcs.get_alt_svcs(&key); self.pool - .get(key, Conns::new, self.config.http1_config.max_conn_num()) + .get( + key, + Conns::new, + self.config.http1_config.max_conn_num(), + self.config.speed_config, + ) .conn( self.config.clone(), self.connector.clone(), @@ -99,6 +105,7 @@ pub(crate) enum H1ConnOption { } pub(crate) struct Conns { + speed_config: SpeedConfig, usable: WrappedSemaphore, list: Arc>>>, #[cfg(feature = "http2")] @@ -108,8 +115,9 @@ pub(crate) struct Conns { } impl Conns { - fn new(max_conn_num: usize) -> Self { + fn new(max_conn_num: usize, speed_config: SpeedConfig) -> Self { Self { + speed_config, usable: WrappedSemaphore::new(max_conn_num), list: Arc::new(Mutex::new(Vec::new())), @@ -128,6 +136,7 @@ impl Conns { impl Clone for Conns { fn clone(&self) -> Self { Self { + speed_config: self.speed_config, usable: self.usable.clone(), list: self.list.clone(), @@ -218,7 +227,7 @@ impl Conns // resulting in the creation of multiple tcp connections let mut lock = self.h2_conn.lock().await; - if let Some(conn) = Self::exist_h2_conn(&mut lock) { + if let Some(conn) = self.exist_h2_conn(&mut lock) { return Ok(TimeInfoConn::new(conn, TimeGroup::default())); } let stream = connector.connect(url, HttpVersion::Http2).await?; @@ -236,7 +245,7 @@ impl Conns _ => {} } let time_group = take(data.time_group_mut()); - let conn = Self::dispatch_h2_conn(data.detail(), config, stream, &mut lock); + let conn = self.dispatch_h2_conn(data.detail(), config, stream, &mut lock); Ok(TimeInfoConn::new(conn, time_group)) } @@ -252,7 +261,7 @@ impl Conns { let mut lock = self.h3_conn.lock().await; - if let Some(conn) = Self::exist_h3_conn(&mut lock) { + if let Some(conn) = self.exist_h3_conn(&mut lock) { return Ok(TimeInfoConn::new(conn, TimeGroup::default())); } let mut stream = connector.connect(url, HttpVersion::Http3).await?; @@ -265,7 +274,7 @@ impl Conns let mut data = stream.conn_data(); let time_group = take(data.time_group_mut()); Ok(TimeInfoConn::new( - Self::dispatch_h3_conn(data.detail(), config, stream, quic_conn, &mut lock), + self.dispatch_h3_conn(data.detail(), config, stream, quic_conn, &mut lock), time_group, )) } @@ -283,7 +292,7 @@ impl Conns match *url.scheme().unwrap() { Scheme::HTTPS => { let mut lock = self.h2_conn.lock().await; - if let Some(conn) = Self::exist_h2_conn(&mut lock) { + if let Some(conn) = self.exist_h2_conn(&mut lock) { return Ok(TimeInfoConn::new(conn, TimeGroup::default())); } let permit = self.usable.acquire().await; @@ -315,7 +324,7 @@ impl Conns )) } else if protocol == b"h2" { std::mem::drop(permit); - let conn = Self::dispatch_h2_conn(data.detail(), h2_config, stream, &mut lock); + let conn = self.dispatch_h2_conn(data.detail(), h2_config, stream, &mut lock); Ok(TimeInfoConn::new(conn, time_group)) } else { std::mem::drop(permit); @@ -338,7 +347,7 @@ impl Conns C: Connector, { let mut lock = self.h3_conn.lock().await; - if let Some(conn) = Self::exist_h3_conn(&mut lock) { + if let Some(conn) = self.exist_h3_conn(&mut lock) { return Some(TimeInfoConn::new(conn, TimeGroup::default())); } if let Some(alt_svcs) = alt_svcs { @@ -364,7 +373,7 @@ impl Conns let mut data = stream.conn_data(); let time_group = take(data.time_group_mut()); return Some(TimeInfoConn::new( - Self::dispatch_h3_conn( + self.dispatch_h3_conn( data.detail(), h3_config.clone(), stream, @@ -385,11 +394,13 @@ impl Conns list.push(dispatcher); #[cfg(any(feature = "http2", feature = "http3"))] if let Conn::Http1(ref mut h1) = conn { + h1.speed_controller.set_speed_limit(self.speed_config); h1.occupy_sem(permit) } #[cfg(all(not(feature = "http2"), not(feature = "http3")))] { let Conn::Http1(ref mut h1) = conn; + h1.speed_controller.set_speed_limit(self.speed_config); h1.occupy_sem(permit) } conn @@ -397,19 +408,24 @@ impl Conns #[cfg(feature = "http2")] fn dispatch_h2_conn( + &self, detail: ConnDetail, config: H2Config, stream: S, lock: &mut crate::runtime::MutexGuard>>, ) -> Conn { let dispatcher = ConnDispatcher::http2(detail, config, stream); - let conn = dispatcher.dispatch().unwrap(); + let mut conn = dispatcher.dispatch().unwrap(); lock.push(dispatcher); + if let Conn::Http2(ref mut h2) = conn { + h2.speed_controller.set_speed_limit(self.speed_config); + } conn } #[cfg(feature = "http3")] fn dispatch_h3_conn( + &self, detail: ConnDetail, config: H3Config, stream: S, @@ -417,8 +433,11 @@ impl Conns lock: &mut crate::runtime::MutexGuard>>, ) -> Conn { let dispatcher = ConnDispatcher::http3(detail, config, stream, quic_connection); - let conn = dispatcher.dispatch().unwrap(); + let mut conn = dispatcher.dispatch().unwrap(); lock.push(dispatcher); + if let Conn::Http3(ref mut h3) = conn { + h3.speed_controller.set_speed_limit(self.speed_config); + } conn } @@ -440,6 +459,7 @@ impl Conns match conn { Some(Conn::Http1(mut h1)) => { h1.occupy_sem(permit); + h1.speed_controller.set_speed_limit(self.speed_config); H1ConnOption::Some(Conn::Http1(h1)) } _ => H1ConnOption::None(permit), @@ -448,6 +468,7 @@ impl Conns #[cfg(feature = "http2")] fn exist_h2_conn( + &self, lock: &mut crate::runtime::MutexGuard>>, ) -> Option> { if let Some(dispatcher) = lock.pop() { @@ -455,9 +476,10 @@ impl Conns return None; } if !dispatcher.is_goaway() { - if let Some(conn) = dispatcher.dispatch() { + if let Some(Conn::Http2(mut h2)) = dispatcher.dispatch() { lock.push(dispatcher); - return Some(conn); + h2.speed_controller.set_speed_limit(self.speed_config); + return Some(Conn::Http2(h2)); } } lock.push(dispatcher); @@ -467,6 +489,7 @@ impl Conns #[cfg(feature = "http3")] fn exist_h3_conn( + &self, lock: &mut crate::runtime::MutexGuard>>, ) -> Option> { if let Some(dispatcher) = lock.pop() { @@ -474,9 +497,10 @@ impl Conns return None; } if !dispatcher.is_goaway() { - if let Some(conn) = dispatcher.dispatch() { + if let Some(Conn::Http3(mut h3)) = dispatcher.dispatch() { lock.push(dispatcher); - return Some(conn); + h3.speed_controller.set_speed_limit(self.speed_config); + return Some(Conn::Http3(h3)); } } // Not all requests have been processed yet diff --git a/ylong_http_client/src/lib.rs b/ylong_http_client/src/lib.rs index d1ec7b3..f1431e4 100644 --- a/ylong_http_client/src/lib.rs +++ b/ylong_http_client/src/lib.rs @@ -83,6 +83,7 @@ pub(crate) mod runtime { #[cfg(all(feature = "tokio_base", feature = "async"))] pub(crate) use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, + macros::support::poll_fn, net::TcpStream, sync::{OwnedSemaphorePermit as SemaphorePermit, Semaphore}, task::{spawn_blocking, JoinHandle}, @@ -90,6 +91,7 @@ pub(crate) mod runtime { }; #[cfg(feature = "ylong_base")] pub(crate) use ylong_runtime::{ + futures::poll_fn, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, net::TcpStream, spawn_blocking, diff --git a/ylong_http_client/src/util/c_openssl/ffi/stack.rs b/ylong_http_client/src/util/c_openssl/ffi/stack.rs index c5f9eb7..35dfdad 100644 --- a/ylong_http_client/src/util/c_openssl/ffi/stack.rs +++ b/ylong_http_client/src/util/c_openssl/ffi/stack.rs @@ -57,13 +57,14 @@ pub(crate) unsafe fn unified_sk_free(st: STACK) { } } -/// Retrieves a pointer to a stack element from a stack allocated by OpenSSL or BoringSSL at the specified index. +/// Retrieves a pointer to a stack element from a stack allocated by OpenSSL or +/// BoringSSL at the specified index. /// /// # Safety /// - `st` must be a valid pointer to a stack allocated by the same library /// (OpenSSL or BoringSSL) used in this crate. -/// - `idx` must be a valid index within the bounds of the stack. -/// if the index is out of range, the function will return `null`. +/// - `idx` must be a valid index within the bounds of the stack. if the index +/// is out of range, the function will return `null`. pub(crate) unsafe fn unified_sk_value(st: STACK, idx: c_int) -> *mut c_void { #[cfg(feature = "c_boringssl")] { @@ -99,8 +100,9 @@ pub(crate) unsafe fn unified_sk_num(st: STACK) -> c_int { /// # Safety /// - `st` must be a valid pointer to a stack allocated by the same library /// (OpenSSL or BoringSSL) used in this crate. -/// - The caller must check the return value of this function. If the stack is empty, -/// the function will return `null`. The caller must handle this case appropriately. +/// - The caller must check the return value of this function. If the stack is +/// empty, the function will return `null`. The caller must handle this case +/// appropriately. pub(crate) unsafe fn unified_sk_pop(st: STACK) -> *mut c_void { #[cfg(feature = "c_boringssl")] { diff --git a/ylong_http_client/src/util/c_openssl/ssl/stream.rs b/ylong_http_client/src/util/c_openssl/ssl/stream.rs index 14ed1b4..01a7dc7 100644 --- a/ylong_http_client/src/util/c_openssl/ssl/stream.rs +++ b/ylong_http_client/src/util/c_openssl/ssl/stream.rs @@ -26,7 +26,6 @@ use crate::c_openssl::error::ErrorStack; use crate::c_openssl::ffi::ssl::{SSL_connect, SSL_set_bio, SSL_shutdown}; use crate::c_openssl::foreign::Foreign; use crate::c_openssl::verify::PinsVerifyInfo; - use crate::util::base64::encode; use crate::util::c_openssl::bio::BioMethod; use crate::util::c_openssl::error::VerifyError; @@ -256,8 +255,9 @@ pub(crate) enum ShutdownResult { } pub(crate) fn verify_server_root_cert(ssl: *const SSL, pinned_key: &str) -> Result<(), SslError> { - use crate::c_openssl::ffi::{ssl::SSL_get_peer_cert_chain, ssl::X509_chain_up_ref}; - use crate::c_openssl::{stack::Stack, x509::X509}; + use crate::c_openssl::ffi::ssl::{SSL_get_peer_cert_chain, X509_chain_up_ref}; + use crate::c_openssl::stack::Stack; + use crate::c_openssl::x509::X509; let cert_chain = unsafe { X509_chain_up_ref(SSL_get_peer_cert_chain(ssl)) }; if cert_chain.is_null() { diff --git a/ylong_http_client/src/util/c_openssl/stack.rs b/ylong_http_client/src/util/c_openssl/stack.rs index 5640b4b..421c0ca 100644 --- a/ylong_http_client/src/util/c_openssl/stack.rs +++ b/ylong_http_client/src/util/c_openssl/stack.rs @@ -19,7 +19,6 @@ use core::ops::{Deref, DerefMut, Range}; use libc::c_int; use super::ffi::stack::{unified_sk_free, unified_sk_num, unified_sk_pop, unified_sk_value, STACK}; - use crate::c_openssl::foreign::{Foreign, ForeignRef, ForeignRefWrapper}; pub(crate) trait Stackof: Foreign { diff --git a/ylong_http_client/src/util/c_openssl/verify/pinning.rs b/ylong_http_client/src/util/c_openssl/verify/pinning.rs index d14bc32..2b4984a 100644 --- a/ylong_http_client/src/util/c_openssl/verify/pinning.rs +++ b/ylong_http_client/src/util/c_openssl/verify/pinning.rs @@ -134,9 +134,9 @@ impl PubKeyPinsBuilder { /// Sets a tuple of (server, public key digest) for `PubKeyPins`, using /// the root certificate pinning strategy. ///
- /// Ensure that the server returns the complete certificate chain, including the root certificate; - /// otherwise, the client's public key pinning validation will fail and return an error. - ///
+ /// Ensure that the server returns the complete certificate chain, including + /// the root certificate; otherwise, the client's public key pinning + /// validation will fail and return an error. /// /// # Examples /// diff --git a/ylong_http_client/src/util/config/http.rs b/ylong_http_client/src/util/config/http.rs index abc5f26..033b222 100644 --- a/ylong_http_client/src/util/config/http.rs +++ b/ylong_http_client/src/util/config/http.rs @@ -12,6 +12,7 @@ // limitations under the License. //! HTTP configure module. +use crate::util::progress::SpeedConfig; #[cfg(feature = "http3")] use crate::ErrorKind; @@ -19,6 +20,7 @@ use crate::ErrorKind; #[derive(Clone)] pub(crate) struct HttpConfig { pub(crate) version: HttpVersion, + pub(crate) speed_config: SpeedConfig, #[cfg(feature = "http1_1")] pub(crate) http1_config: http1::H1Config, @@ -35,13 +37,11 @@ impl HttpConfig { pub(crate) fn new() -> Self { Self { version: HttpVersion::Negotiate, - + speed_config: SpeedConfig::none(), #[cfg(feature = "http1_1")] http1_config: http1::H1Config::default(), - #[cfg(feature = "http2")] http2_config: http2::H2Config::new(), - #[cfg(feature = "http3")] http3_config: http3::H3Config::new(), } diff --git a/ylong_http_client/src/util/data_ref.rs b/ylong_http_client/src/util/data_ref.rs index 1d53e11..7d92b5f 100644 --- a/ylong_http_client/src/util/data_ref.rs +++ b/ylong_http_client/src/util/data_ref.rs @@ -17,15 +17,19 @@ use std::pin::Pin; use std::task::{Context, Poll}; use crate::runtime::{AsyncRead, ReadBuf}; +use crate::util::progress::SpeedController; use crate::util::request::RequestArc; +use crate::HttpClientError; pub(crate) struct BodyDataRef { + pub(crate) speed_controller: SpeedController, body: Option, } impl BodyDataRef { - pub(crate) fn new(request: RequestArc) -> Self { + pub(crate) fn new(request: RequestArc, speed_controller: SpeedController) -> Self { Self { + speed_controller, body: Some(request), } } @@ -38,18 +42,33 @@ impl BodyDataRef { &mut self, cx: &mut Context<'_>, buf: &mut [u8], - ) -> Poll> { + ) -> Poll> { let request = if let Some(ref mut request) = self.body { request } else { - return Poll::Ready(Some(0)); + return Poll::Ready(Ok(0)); }; + self.speed_controller.init_min_send_if_not_start(); + if self + .speed_controller + .poll_max_send_delay_time(cx) + .is_pending() + { + return Poll::Pending; + } + self.speed_controller.init_max_send_if_not_start(); let data = request.ref_mut().body_mut(); let mut read_buf = ReadBuf::new(buf); let data = Pin::new(data); match data.poll_read(cx, &mut read_buf) { - Poll::Ready(Err(_)) => Poll::Ready(None), - Poll::Ready(Ok(_)) => Poll::Ready(Some(read_buf.filled().len())), + Poll::Ready(Err(e)) => Poll::Ready(err_from_io!(BodyTransfer, e)), + Poll::Ready(Ok(_)) => { + let filled: usize = read_buf.filled().len(); + // Limit the write I/O speed by limiting the read file speed. + self.speed_controller.min_send_speed_limit(filled)?; + self.speed_controller.delay_max_send_speed_limit(filled); + Poll::Ready(Ok(filled)) + } Poll::Pending => Poll::Pending, } } diff --git a/ylong_http_client/src/util/dispatcher.rs b/ylong_http_client/src/util/dispatcher.rs index c675e95..cee41ef 100644 --- a/ylong_http_client/src/util/dispatcher.rs +++ b/ylong_http_client/src/util/dispatcher.rs @@ -136,6 +136,7 @@ pub(crate) mod http1 { use crate::runtime::Semaphore; #[cfg(feature = "tokio_base")] use crate::runtime::SemaphorePermit; + use crate::util::progress::SpeedController; impl ConnDispatcher { pub(crate) fn http1(io: S) -> Self { @@ -195,13 +196,18 @@ pub(crate) mod http1 { /// Handle returned to other threads for I/O operations. pub(crate) struct Http1Conn { + pub(crate) speed_controller: SpeedController, pub(crate) sem: Option, pub(crate) inner: Arc>, } impl Http1Conn { pub(crate) fn from_inner(inner: Arc>) -> Self { - Self { sem: None, inner } + Self { + speed_controller: SpeedController::none(), + sem: None, + inner, + } } pub(crate) fn occupy_sem(&mut self, sem: WrappedSemPermit) { @@ -305,8 +311,10 @@ pub(crate) mod http2 { ConnManager, FlowControl, H2StreamState, RecvData, RequestWrapper, SendData, StreamEndState, Streams, }; + use crate::util::progress::SpeedController; use crate::ErrorKind::Request; use crate::{ConnDetail, ErrorKind, HttpClientError}; + const DEFAULT_MAX_FRAME_SIZE: usize = 2 << 13; const DEFAULT_WINDOW_SIZE: u32 = 65535; @@ -349,6 +357,7 @@ pub(crate) mod http2 { } pub(crate) struct Http2Conn { + pub(crate) speed_controller: SpeedController, pub(crate) allow_cached_frames: usize, // Sends frame to StreamController pub(crate) sender: UnboundedSender, @@ -538,6 +547,7 @@ pub(crate) mod http2 { detail: ConnDetail, ) -> Self { Self { + speed_controller: SpeedController::none(), allow_cached_frames: allow_cached_num, sender, receiver: RespReceiver::default(), @@ -690,7 +700,7 @@ pub(crate) mod http2 { Some(ref mut receiver) => { #[cfg(feature = "tokio_base")] match receiver.recv().await { - None => err_from_msg!(Request, "Response Receiver Closed !"), + None => err_from_msg!(Request, "Response Sender Closed !"), Some(message) => match message { RespMessage::Output(frame) => Ok(frame), RespMessage::OutputExit(e) => Err(dispatch_client_error(e)), @@ -722,7 +732,7 @@ pub(crate) mod http2 { #[cfg(feature = "tokio_base")] match receiver.poll_recv(cx) { Poll::Ready(None) => { - Poll::Ready(err_from_msg!(Request, "Error receive response !")) + Poll::Ready(err_from_msg!(Request, "Response Sender Closed !")) } Poll::Ready(Some(message)) => match message { RespMessage::Output(frame) => Poll::Ready(Ok(frame)), @@ -819,6 +829,7 @@ pub(crate) mod http3 { use crate::util::dispatcher::{ConnDispatcher, Dispatcher}; use crate::util::h3::io_manager::IOManager; use crate::util::h3::stream_manager::StreamManager; + use crate::util::progress::SpeedController; use crate::ErrorKind::Request; use crate::{ConnDetail, ConnInfo, ErrorKind, HttpClientError}; @@ -832,6 +843,7 @@ pub(crate) mod http3 { } pub(crate) struct Http3Conn { + pub(crate) speed_controller: SpeedController, pub(crate) sender: UnboundedSender, pub(crate) resp_receiver: BoundedReceiver, pub(crate) resp_sender: BoundedSender, @@ -931,6 +943,7 @@ pub(crate) mod http3 { const CHANNEL_SIZE: usize = 3; let (resp_sender, resp_receiver) = bounded_channel(CHANNEL_SIZE); Self { + speed_controller: SpeedController::none(), sender, resp_sender, resp_receiver, @@ -957,7 +970,7 @@ pub(crate) mod http3 { pub(crate) async fn recv_resp(&mut self) -> Result { #[cfg(feature = "tokio_base")] match self.resp_receiver.recv().await { - None => err_from_msg!(Request, "Response Receiver Closed !"), + None => err_from_msg!(Request, "Response Sender Closed !"), Some(message) => match message { RespMessage::Output(frame) => Ok(frame), RespMessage::OutputExit(e) => Err(dispatch_client_error(e)), diff --git a/ylong_http_client/src/util/h2/manager.rs b/ylong_http_client/src/util/h2/manager.rs index be9b62f..42330fe 100644 --- a/ylong_http_client/src/util/h2/manager.rs +++ b/ylong_http_client/src/util/h2/manager.rs @@ -230,25 +230,36 @@ impl ConnManager { } loop { - match self.controller.streams.poll_read_body(cx, id)? { - DataReadState::Closed => { - break; - } - DataReadState::Pending => { - break; - } - DataReadState::Ready(data) => { - self.poll_send_frame(data)?; - } - DataReadState::Finish(frame) => { - self.poll_send_frame(frame)?; - break; - } + match self.controller.streams.poll_read_body(cx, id) { + Ok(state) => match state { + DataReadState::Closed => break, + DataReadState::Pending => break, + DataReadState::Ready(data) => self.poll_send_frame(data)?, + DataReadState::Finish(frame) => { + self.poll_send_frame(frame)?; + break; + } + }, + Err(e) => return self.deal_poll_body_error(cx, e), } } Ok(()) } + fn deal_poll_body_error( + &mut self, + cx: &mut Context<'_>, + e: H2Error, + ) -> Result<(), DispatchErrorKind> { + match e { + H2Error::StreamError(id, code) => match self.manage_stream_error(cx, id, code) { + Poll::Ready(res) => res, + Poll::Pending => Ok(()), + }, + H2Error::ConnectionError(e) => Err(H2Error::ConnectionError(e).into()), + } + } + fn poll_send_frame(&mut self, frame: Frame) -> Result<(), DispatchErrorKind> { match frame.payload() { Payload::Headers(_) => { @@ -273,7 +284,8 @@ impl ConnManager { } _ => {} } - + // TODO Replace with a bounded channel to avoid excessive local memory overhead + // when I/O is blocked in the process of uploading large files. self.input_tx .send(frame) .map_err(|_e| DispatchErrorKind::ChannelClosed) @@ -556,7 +568,7 @@ impl ConnManager { match self.controller.send_message_to_stream( cx, id, - RespMessage::OutputExit(DispatchErrorKind::ChannelClosed), + RespMessage::OutputExit(DispatchErrorKind::H2(H2Error::StreamError(id, code))), ) { Poll::Ready(_) => { // error at the stream level due to early exit of the coroutine in which the diff --git a/ylong_http_client/src/util/h2/streams.rs b/ylong_http_client/src/util/h2/streams.rs index b1a395e..192f1b0 100644 --- a/ylong_http_client/src/util/h2/streams.rs +++ b/ylong_http_client/src/util/h2/streams.rs @@ -462,7 +462,7 @@ impl Streams { return Err(H2Error::ConnectionError(ErrorCode::IntervalError)); }; match stream.data.poll_read(cx, buf) { - Poll::Ready(Some(size)) => { + Poll::Ready(Ok(size)) => { if size > 0 { stream.send_window.send_data(size as u32); self.flow_control.send_data(size as u32); @@ -485,7 +485,7 @@ impl Streams { ))) } } - Poll::Ready(None) => Err(H2Error::ConnectionError(ErrorCode::IntervalError)), + Poll::Ready(Err(_)) => Err(H2Error::StreamError(id, ErrorCode::IntervalError)), Poll::Pending => { self.push_back_pending_send(id); Ok(DataReadState::Pending) @@ -808,6 +808,7 @@ mod ut_streams { use super::*; use crate::async_impl::{Body, Request}; use crate::request::RequestArc; + use crate::util::progress::SpeedController; fn stream_new(state: H2StreamState) -> Stream { Stream { @@ -815,9 +816,10 @@ mod ut_streams { recv_window: RecvWindow::new(100), state, header: None, - data: BodyDataRef::new(RequestArc::new( - Request::builder().body(Body::empty()).unwrap(), - )), + data: BodyDataRef::new( + RequestArc::new(Request::builder().body(Body::empty()).unwrap()), + SpeedController::none(), + ), } } @@ -885,7 +887,8 @@ mod ut_streams { /// # Brief /// 1. Insert streams with different states and sends go_away with a stream /// id. - /// 2. Asserts that only streams with IDs greater than to the go_away ID are closed. + /// 2. Asserts that only streams with IDs greater than to the go_away ID are + /// closed. #[test] fn ut_streams_get_unset_streams() { let mut streams = Streams::new(100, 100, FlowControl::new(100, 100)); diff --git a/ylong_http_client/src/util/h3/streams.rs b/ylong_http_client/src/util/h3/streams.rs index 6345d67..2fc60f4 100644 --- a/ylong_http_client/src/util/h3/streams.rs +++ b/ylong_http_client/src/util/h3/streams.rs @@ -485,7 +485,7 @@ impl Streams { } match stream.data.poll_read(cx, buf) { - Poll::Ready(Some(size)) => { + Poll::Ready(Ok(size)) => { if size > 0 { let data_vec = Vec::from(&buf[..size]); Ok(DataReadState::Ready(Box::new(Frame::new( @@ -496,7 +496,7 @@ impl Streams { Ok(DataReadState::Finish) } } - Poll::Ready(None) => Err(DispatchErrorKind::H3(H3Error::Connection( + Poll::Ready(Err(_)) => Err(DispatchErrorKind::H3(H3Error::Connection( H3ErrorCode::H3InternalError, ))), Poll::Pending => { diff --git a/ylong_http_client/src/util/mod.rs b/ylong_http_client/src/util/mod.rs index 700aa03..576b183 100644 --- a/ylong_http_client/src/util/mod.rs +++ b/ylong_http_client/src/util/mod.rs @@ -46,6 +46,7 @@ pub(crate) mod h3; pub(crate) mod information; pub(crate) mod interceptor; pub(crate) mod monitor; +pub(crate) mod progress; #[cfg(all(test, feature = "ylong_base"))] pub(crate) mod test_utils; diff --git a/ylong_http_client/src/util/pool.rs b/ylong_http_client/src/util/pool.rs index d8c0682..c4c0b6b 100644 --- a/ylong_http_client/src/util/pool.rs +++ b/ylong_http_client/src/util/pool.rs @@ -20,6 +20,8 @@ use std::sync::{Arc, Mutex}; use ylong_http::request::uri::{Authority, Scheme}; +use crate::util::progress::SpeedConfig; + pub(crate) struct Pool { pool: Arc>>, } @@ -33,14 +35,20 @@ impl Pool { } impl Pool { - pub(crate) fn get(&self, key: K, create_fn: F, allowed_num: usize) -> V + pub(crate) fn get( + &self, + key: K, + create_fn: F, + allowed_num: usize, + speed_conf: SpeedConfig, + ) -> V where - F: FnOnce(usize) -> V, + F: FnOnce(usize, SpeedConfig) -> V, { let mut inner = self.pool.lock().unwrap(); match (*inner).entry(key) { Entry::Occupied(conns) => conns.get().clone(), - Entry::Vacant(e) => e.insert(create_fn(allowed_num)).clone(), + Entry::Vacant(e) => e.insert(create_fn(allowed_num, speed_conf)).clone(), } } } @@ -59,6 +67,7 @@ mod ut_pool { use ylong_http::request::uri::Uri; use crate::pool::{Pool, PoolKey}; + use crate::util::progress::SpeedConfig; /// UT test cases for `Pool::get`. /// @@ -74,9 +83,9 @@ mod ut_pool { uri.authority().unwrap().clone(), ); let data = String::from("Data info"); - let consume_and_return_data = move |_size: usize| data; + let consume_and_return_data = move |_size: usize, _conf: SpeedConfig| data; let pool = Pool::new(); - let res = pool.get(key, consume_and_return_data, 6); + let res = pool.get(key, consume_and_return_data, 6, SpeedConfig::none()); assert_eq!(res, "Data info".to_string()); } } diff --git a/ylong_http_client/src/util/progress/mod.rs b/ylong_http_client/src/util/progress/mod.rs new file mode 100644 index 0000000..5c5954d --- /dev/null +++ b/ylong_http_client/src/util/progress/mod.rs @@ -0,0 +1,16 @@ +// Copyright (c) 2024 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod rate_limit; + +pub(crate) use rate_limit::{SpeedConfig, SpeedController}; diff --git a/ylong_http_client/src/util/progress/rate_limit.rs b/ylong_http_client/src/util/progress/rate_limit.rs new file mode 100644 index 0000000..09d8475 --- /dev/null +++ b/ylong_http_client/src/util/progress/rate_limit.rs @@ -0,0 +1,441 @@ +// Copyright (c) 2024 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::{Duration, Instant}; + +use crate::runtime::{sleep, Sleep}; +use crate::HttpClientError; + +pub(crate) const SPEED_CHECK_PERIOD: Duration = Duration::from_millis(1000); + +#[derive(Default, Clone)] +pub(crate) struct SpeedController { + pub(crate) send_rate_limit: RateLimit, + pub(crate) recv_rate_limit: RateLimit, +} + +impl SpeedController { + pub(crate) fn none() -> Self { + SpeedController::default() + } + + pub(crate) fn set_speed_limit(&mut self, config: SpeedConfig) { + if let Some(speed) = config.max_recv_speed() { + self.recv_rate_limit + .set_max_speed(speed, SPEED_CHECK_PERIOD); + } + + if let Some(speed) = config.min_recv_speed() { + if let Some(interval) = config.min_speed_interval() { + self.recv_rate_limit.set_min_speed( + speed, + SPEED_CHECK_PERIOD, + Duration::from_secs(interval), + ); + } + } + + if let Some(speed) = config.max_send_speed() { + self.send_rate_limit + .set_max_speed(speed, SPEED_CHECK_PERIOD); + } + + if let Some(speed) = config.min_send_speed() { + if let Some(interval) = config.min_speed_interval() { + self.send_rate_limit.set_min_speed( + speed, + SPEED_CHECK_PERIOD, + Duration::from_secs(interval), + ); + } + } + } + + pub(crate) fn need_limit_max_send_speed(&self) -> bool { + self.send_rate_limit.need_limit_max_speed() + } + + pub(crate) async fn max_send_speed_limit(&mut self, size: usize) { + self.send_rate_limit.max_speed_limit(size).await + } + + pub(crate) fn delay_max_recv_speed_limit(&mut self, size: usize) { + self.recv_rate_limit.delay_max_speed_limit(size) + } + + #[cfg(any(feature = "http2", feature = "http3"))] + pub(crate) fn delay_max_send_speed_limit(&mut self, size: usize) { + self.send_rate_limit.delay_max_speed_limit(size) + } + + pub(crate) fn min_send_speed_limit(&mut self, size: usize) -> Result<(), HttpClientError> { + self.send_rate_limit.min_speed_limit(size) + } + + pub(crate) fn reset_send_pending_timeout(&mut self) { + self.send_rate_limit.reset_pending_timeout() + } + + pub(crate) fn min_recv_speed_limit(&mut self, size: usize) -> Result<(), HttpClientError> { + self.recv_rate_limit.min_speed_limit(size) + } + + pub(crate) fn reset_recv_pending_timeout(&mut self) { + self.recv_rate_limit.reset_pending_timeout() + } + + pub(crate) fn poll_max_recv_delay_time(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.recv_rate_limit.poll_limited_delay(cx) + } + + pub(crate) fn poll_recv_pending_timeout(&mut self, cx: &mut Context<'_>) -> bool { + self.recv_rate_limit.poll_pending_timeout(cx) + } + + pub(crate) fn poll_send_pending_timeout(&mut self, cx: &mut Context<'_>) -> bool { + self.send_rate_limit.poll_pending_timeout(cx) + } + + #[cfg(any(feature = "http2", feature = "http3"))] + pub(crate) fn poll_max_send_delay_time(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.send_rate_limit.poll_limited_delay(cx) + } + + pub(crate) fn init_max_send_if_not_start(&mut self) { + self.send_rate_limit.init_max_limit_if_not_start(); + } + + pub(crate) fn init_min_send_if_not_start(&mut self) { + self.send_rate_limit.init_min_limit_if_not_start(); + } + + pub(crate) fn init_max_recv_if_not_start(&mut self) { + self.recv_rate_limit.init_max_limit_if_not_start(); + } + + pub(crate) fn init_min_recv_if_not_start(&mut self) { + self.recv_rate_limit.init_min_limit_if_not_start(); + } +} + +#[derive(Default, Clone)] +pub(crate) struct RateLimit { + min_speed: Option, + max_speed: Option, +} + +impl RateLimit { + pub(crate) fn set_min_speed(&mut self, rate: u64, period: Duration, interval: Duration) { + let limit = SpeedLimit::new(rate, period, interval); + self.min_speed = Some(limit) + } + + pub(crate) fn set_max_speed(&mut self, rate: u64, period: Duration) { + let limit = SpeedLimit::new(rate, period, Duration::default()); + self.max_speed = Some(limit) + } + + pub(crate) fn need_limit_max_speed(&self) -> bool { + self.max_speed.is_some() + } + + pub(crate) fn init_max_limit_if_not_start(&mut self) { + if let Some(ref mut speed) = self.max_speed { + speed.init_if_not_start() + } + } + + pub(crate) fn init_min_limit_if_not_start(&mut self) { + if let Some(ref mut speed) = self.min_speed { + speed.init_if_not_start() + } + } + + pub(crate) async fn max_speed_limit(&mut self, read: usize) { + if let Some(ref mut speed) = self.max_speed { + speed.limit_max_speed(read).await + } + } + + pub(crate) fn delay_max_speed_limit(&mut self, read: usize) { + if let Some(ref mut speed) = self.max_speed { + speed.delay_max_speed_limit(read) + } + } + + pub(crate) fn min_speed_limit(&mut self, read: usize) -> Result<(), HttpClientError> { + if let Some(ref mut speed) = self.min_speed { + speed.limit_min_speed(read) + } else { + Ok(()) + } + } + + pub(crate) fn reset_pending_timeout(&mut self) { + if let Some(ref mut speed) = self.min_speed { + speed.reset_pending_timeout() + } + } + + pub(crate) fn poll_pending_timeout(&mut self, cx: &mut Context<'_>) -> bool { + self.min_speed + .as_mut() + .is_some_and(|speed| speed.poll_pending_timeout(cx)) + } + + pub(crate) fn poll_limited_delay(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if let Some(ref mut speed) = self.max_speed { + return speed.poll_max_limited_delay(cx); + } + Poll::Ready(()) + } +} + +#[derive(Default)] +pub(crate) struct SpeedLimit { + rate: u64, + // Speed limiting period, millisecond. + period: Duration, + min_speed_interval: Duration, + // min_speed_interval start time. + min_speed_start: Option, + // Data received within a period, byte. + period_data: u64, + // The elapsed time in the period. + elapsed_time: Duration, + // The maximum data allowed within a period, byte. + max_speed_allowed_bytes: u64, + // The start time of each io read or write. + start: Option, + // The time delay required to trigger the maximum speed limit. + delay: Option>>, + // min_speed_interval Pending Timeout time. + timeout: Option>>, +} + +impl SpeedLimit { + /// Creates a new `SpeedLimit`. + /// `rate` is the download size allowed within a period, expressed in + /// bytes/second. + pub(crate) fn new(rate: u64, period: Duration, interval: Duration) -> SpeedLimit { + SpeedLimit { + rate, + period, + min_speed_interval: interval, + min_speed_start: None, + period_data: 0, + elapsed_time: Duration::default(), + max_speed_allowed_bytes: rate * period.as_secs(), + start: None, + delay: None, + timeout: Some(Box::pin(sleep(interval))), + } + } + + pub(crate) fn init_if_not_start(&mut self) { + self.start.get_or_insert(Instant::now()); + } + + pub(crate) fn poll_pending_timeout(&mut self, cx: &mut Context<'_>) -> bool { + self.timeout + .as_mut() + .is_some_and(|timeout| Pin::new(timeout).poll(cx).is_ready()) + } + + pub(crate) fn poll_max_limited_delay(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if let Some(delay) = self.delay.as_mut() { + return match Pin::new(delay).poll(cx) { + Poll::Ready(()) => { + self.delay = None; + self.next_period(); + Poll::Ready(()) + } + Poll::Pending => Poll::Pending, + }; + } + Poll::Ready(()) + } + + pub(crate) fn delay_max_speed_limit(&mut self, data_size: usize) { + if let Some(start_time) = self.start.take() { + self.elapsed_time += start_time.elapsed(); + self.period_data += data_size as u64; + if self.elapsed_time < self.period { + if self.period_data >= self.max_speed_allowed_bytes { + // The minimum milliseconds to download this data within the speed limit. + let limited_time = Duration::from_millis(self.period_data * 1000 / self.rate); + // We will not poll here immediately because the data has not yet been returned + // to user. + self.delay = Some(Box::pin(sleep(limited_time - self.elapsed_time))); + } + } else { + // The minimum milliseconds to download this data within the speed limit. + let limited_time = Duration::from_millis(self.period_data * 1000 / self.rate); + if self.elapsed_time < limited_time { + // We will not poll here immediately because the data has not yet been returned + // to user. + self.delay = Some(Box::pin(sleep(limited_time - self.elapsed_time))); + } else { + // We don't count the part that goes beyond the period, and we go straight to + // the next period. + self.next_period() + } + } + } + } + + pub(crate) async fn limit_max_speed(&mut self, data_size: usize) { + if let Some(start_time) = self.start.take() { + // let elapsed_total = start_time.elapsed(); + self.elapsed_time += start_time.elapsed(); + self.period_data += data_size as u64; + if self.elapsed_time < self.period { + if self.period_data >= self.max_speed_allowed_bytes { + // The minimum milliseconds to download this data within the speed limit. + let limited_time = Duration::from_millis(self.period_data * 1000 / self.rate); + sleep(limited_time - self.elapsed_time).await; + self.next_period(); + } + } else { + // The minimum milliseconds to download this data within the speed limit. + let limited_time = Duration::from_millis(self.period_data * 1000 / self.rate); + if self.elapsed_time < limited_time { + sleep(limited_time - self.elapsed_time).await; + } + // We don't count the part that goes beyond the period, and we go straight to + // the next period. + self.next_period() + } + } + } + + pub(crate) fn limit_min_speed(&mut self, data_size: usize) -> Result<(), HttpClientError> { + if let Some(start_time) = self.start.take() { + self.min_speed_start.get_or_insert(start_time); + self.elapsed_time += start_time.elapsed(); + if self.elapsed_time >= self.period { + self.check_min_speed(data_size)?; + } else { + self.period_data += data_size as u64; + } + } + Ok(()) + } + + pub(crate) fn reset_pending_timeout(&mut self) { + self.timeout = Some(Box::pin(sleep(self.min_speed_interval))); + } + + fn check_min_speed(&mut self, data_size: usize) -> Result<(), HttpClientError> { + self.period_data += data_size as u64; + // The time it takes to process period_data at the minimum speed limit. + let limited_time = Duration::from_millis(self.period_data * 1000 / self.rate); + if self.elapsed_time > limited_time { + // self.min_speed_start must be Some because it was assigned before this + // function was called. + if let Some(ref check_start) = self.min_speed_start { + let check_elapsed = check_start.elapsed(); + // If the time at min_speed_limit exceeds min_speed_interval, an error is + // raised. + if check_elapsed > self.min_speed_interval { + return err_from_msg!(BodyTransfer, "Below low speed limit"); + } + } + } else { + // If the speed exceeds min_speed_limit, min_speed_interval is reset + // immediately. + self.next_interval(); + } + self.next_period(); + Ok(()) + } + + fn next_period(&mut self) { + self.period_data = 0; + self.start = None; + self.elapsed_time = Duration::default(); + } + + fn next_interval(&mut self) { + self.min_speed_start = None + } +} + +impl Clone for SpeedLimit { + fn clone(&self) -> Self { + Self { + rate: self.rate, + period: self.period, + min_speed_interval: self.min_speed_interval, + min_speed_start: None, + period_data: self.period_data, + elapsed_time: self.elapsed_time, + max_speed_allowed_bytes: self.max_speed_allowed_bytes, + start: None, + delay: None, + timeout: None, + } + } +} + +#[derive(Default, Copy, Clone)] +pub(crate) struct SpeedConfig { + max_recv: Option, + min_recv: Option, + max_send: Option, + min_send: Option, + min_speed_interval: Option, +} + +impl SpeedConfig { + pub(crate) fn none() -> SpeedConfig { + Self::default() + } + + pub(crate) fn set_max_rate(&mut self, rate: u64) { + self.max_recv = Some(rate); + self.max_send = Some(rate) + } + + pub(crate) fn set_min_rate(&mut self, rate: u64) { + self.min_send = Some(rate); + self.min_recv = Some(rate) + } + + pub(crate) fn set_min_speed_interval(&mut self, seconds: u64) { + self.min_speed_interval = Some(seconds) + } + + pub(crate) fn max_recv_speed(&self) -> Option { + self.max_recv + } + + pub(crate) fn max_send_speed(&self) -> Option { + self.max_send + } + + pub(crate) fn min_recv_speed(&self) -> Option { + self.min_recv + } + + pub(crate) fn min_send_speed(&self) -> Option { + self.min_send + } + + pub(crate) fn min_speed_interval(&self) -> Option { + self.min_speed_interval + } +} diff --git a/ylong_http_client/tests/sdv_async_https_pinning.rs b/ylong_http_client/tests/sdv_async_https_pinning.rs index 963d258..974a54f 100644 --- a/ylong_http_client/tests/sdv_async_https_pinning.rs +++ b/ylong_http_client/tests/sdv_async_https_pinning.rs @@ -397,7 +397,8 @@ fn sdv_client_public_key_root_pinning() { .expect("Runtime block on server shutdown failed"); } - // Root certificate pinning strategy, but using the server certificate public key hash + // Root certificate pinning strategy, but using the server certificate public + // key hash. { start_server!( HTTPS; -- Gitee