From ce6cd075825cb6d567cd31799e874b4b36195abf Mon Sep 17 00:00:00 2001 From: SurName <472256532@qq.com> Date: Mon, 2 Dec 2024 16:21:38 +0800 Subject: [PATCH] DNS cache reuse and DOH Signed-off-by: xujunyang --- ylong_http_client/Cargo.toml | 5 + ylong_http_client/examples/async_http_doh.rs | 93 ++ .../src/async_impl/connector/mod.rs | 22 +- ylong_http_client/src/async_impl/dns/mod.rs | 4 +- .../src/async_impl/dns/resolver.rs | 921 +++++++++++++----- ylong_http_client/src/async_impl/mod.rs | 4 +- ylong_http_client/src/lib.rs | 3 +- 7 files changed, 817 insertions(+), 235 deletions(-) create mode 100644 ylong_http_client/examples/async_http_doh.rs diff --git a/ylong_http_client/Cargo.toml b/ylong_http_client/Cargo.toml index 108d7f7..d75aea7 100644 --- a/ylong_http_client/Cargo.toml +++ b/ylong_http_client/Cargo.toml @@ -55,6 +55,11 @@ name = "async_http" path = "examples/async_http.rs" required-features = ["async", "http1_1", "ylong_base"] +[[example]] +name = "async_http_doh" +path = "examples/async_http_doh.rs" +required-features = ["async", "http1_1", "ylong_base", "__tls"] + [[example]] name = "async_http_multi" path = "examples/async_http_multi.rs" diff --git a/ylong_http_client/examples/async_http_doh.rs b/ylong_http_client/examples/async_http_doh.rs new file mode 100644 index 0000000..94a76a0 --- /dev/null +++ b/ylong_http_client/examples/async_http_doh.rs @@ -0,0 +1,93 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! This is a simple asynchronous HTTP client example using the +//! ylong_http_client crate. It demonstrates creating a client, making a +//! request, and reading the response asynchronously. + +use ylong_http_client::async_impl::{ + Body, Client, DefaultDnsResolver, DnsResolver, DohResolver, Downloader, Request, +}; +use ylong_http_client::HttpClientError; + +fn main() -> Result<(), HttpClientError> { + let handle = ylong_runtime::spawn(async move { + connect().await.unwrap(); + }); + + let _ = ylong_runtime::block_on(handle); + Ok(()) +} + +async fn connect() -> Result<(), HttpClientError> { + // Creates a `Default Dns Resolver`. + let default_dns_resolver = DefaultDnsResolver::new(std::time::Duration::from_secs(60)); + + // Creates a `Dns Resolver` + let dns_resolver = DnsResolver::default().add_dns_server("119.29.29.29:53"); + + // Creates a `DoH Resolver` + let doh_resolver = DohResolver::default() + .add_doh_server("https://1.12.12.12/dns-query") + .add_doh_server("https://120.53.53.53/dns-query"); + + // Creates a `async_impl::Client` + let default_dns_client = Client::builder() + .dns_resolver(default_dns_resolver) + .build() + .unwrap(); + + let dns_client = Client::builder() + .dns_resolver(dns_resolver) + .build() + .unwrap(); + + let doh_client = Client::builder() + .dns_resolver(doh_resolver) + .build() + .unwrap(); + + // Sends request and receives a `Response`. + let doh_response = doh_client + .request( + Request::builder() + .url("https://www.example.com") + .body(Body::empty())?, + ) + .await?; + + let dns_response = dns_client + .request( + Request::builder() + .url("https://www.example.com") + .body(Body::empty())?, + ) + .await?; + + let default_dns_response = default_dns_client + .request( + Request::builder() + .url("https://www.example.com") + .body(Body::empty())?, + ) + .await?; + + // Reads the body of `Response` by using `BodyReader`. + let _ = Downloader::console(doh_response).download().await; + + let _ = Downloader::console(dns_response).download().await; + + let _ = Downloader::console(default_dns_response).download().await; + + Ok(()) +} diff --git a/ylong_http_client/src/async_impl/connector/mod.rs b/ylong_http_client/src/async_impl/connector/mod.rs index b41727c..16ad412 100644 --- a/ylong_http_client/src/async_impl/connector/mod.rs +++ b/ylong_http_client/src/async_impl/connector/mod.rs @@ -18,6 +18,7 @@ mod stream; use core::future::Future; use std::io::{Error, ErrorKind}; use std::net::SocketAddr; +use std::str::FromStr; use std::sync::Arc; use ylong_http::request::uri::Uri; @@ -111,14 +112,19 @@ async fn dns_query( resolver: Arc, addr: &str, ) -> Result, HttpClientError> { - let addr_fut = resolver.resolve(addr); - let socket_addr = addr_fut.await.map_err(|e| { - HttpClientError::from_dns_error( - crate::ErrorKind::Connect, - Error::new(ErrorKind::Interrupted, e), - ) - })?; - Ok(socket_addr.collect::>()) + match SocketAddr::from_str(addr) { + Ok(socket_addr) => Ok(vec![socket_addr]), + Err(_) => { + let addr_fut = resolver.resolve(addr); + let socket_addr = addr_fut.await.map_err(|e| { + HttpClientError::from_dns_error( + crate::ErrorKind::Connect, + Error::new(ErrorKind::Interrupted, e), + ) + })?; + Ok(socket_addr.collect::>()) + } + } } async fn eyeballs_connect( diff --git a/ylong_http_client/src/async_impl/dns/mod.rs b/ylong_http_client/src/async_impl/dns/mod.rs index 2b8091d..dd5a05c 100644 --- a/ylong_http_client/src/async_impl/dns/mod.rs +++ b/ylong_http_client/src/async_impl/dns/mod.rs @@ -25,4 +25,6 @@ mod happy_eyeballs; mod resolver; pub(crate) use happy_eyeballs::{EyeBallConfig, HappyEyeballs}; -pub use resolver::{Addrs, DefaultDnsResolver, Resolver, SocketFuture, StdError}; +pub use resolver::{ + Addrs, DefaultDnsResolver, DnsResolver, DohResolver, Resolver, SocketFuture, StdError, +}; diff --git a/ylong_http_client/src/async_impl/dns/resolver.rs b/ylong_http_client/src/async_impl/dns/resolver.rs index ab021bb..76d40d6 100644 --- a/ylong_http_client/src/async_impl/dns/resolver.rs +++ b/ylong_http_client/src/async_impl/dns/resolver.rs @@ -17,17 +17,22 @@ use std::collections::HashMap; use std::future::Future; use std::io; use std::io::Error; -use std::net::{SocketAddr, ToSocketAddrs}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, ToSocketAddrs, UdpSocket}; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::str::FromStr; +use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use std::vec::IntoIter; +use crate::async_impl::{Body, Client, Request}; use crate::runtime::JoinHandle; - +use crate::{ErrorKind, HttpClientError}; const DEFAULT_TTL: Duration = Duration::from_secs(60); -const MAX_ENTRIES_LEN: usize = 30000; +const DEFAULT_DOH_SERVER: &str = "https://1.12.12.12/dns-query"; // use tencent doh server +const DEFAULT_DNS_SERVER: &str = "119.29.29.29:53"; +const DEFAULT_MAX_LEN: usize = 30000; +const DEFAULT_MAX_RETRY_COUNT: i32 = 1; /// `SocketAddr` resolved by `Resolver`. pub type Addrs = Box + Sync + Send>; @@ -94,68 +99,39 @@ impl Future for DefaultDnsFuture { } } -/// Default dns resolver used by the `Client`. -/// DefaultDnsResolver provides DNS resolver with caching machanism. -pub struct DefaultDnsResolver { - manager: DnsManager, // Manages DNS cache - connector: DnsConnector, // Performing DNS resolution - ttl: Duration, // Time-to-live for the DNS cache -} - -impl Default for DefaultDnsResolver { - // Default constructor for `DefaultDnsResolver`, with a default TTL of 60 - // seconds. - fn default() -> Self { - DefaultDnsResolver { - manager: DnsManager::default(), - connector: DnsConnector {}, - ttl: DEFAULT_TTL, // Default TTL set to 60 seconds - } - } -} - -impl DefaultDnsResolver { - /// Create a new DefaultDnsResolver. And TTL is Time to live for cache. - /// - /// # Examples - /// - /// ``` - /// use std::time::Duration; - /// - /// use ylong_http_client::async_impl::DefaultDnsResolver; - /// - /// let res = DefaultDnsResolver::new(Duration::from_secs(1)); - /// ``` - pub fn new(ttl: Duration) -> Self { - DefaultDnsResolver { - manager: DnsManager::new(), - connector: DnsConnector {}, - ttl, // Set TTL through the passed parameters - } - } -} - -#[derive(Default)] +#[derive(Clone)] struct DnsManager { // Cache storing authority and DNS results - map: Mutex>, + map: Arc>>, + max_entries_len: usize, } impl DnsManager { // Creates a new `DnsManager` instance with an empty cache fn new() -> Self { DnsManager { - map: Mutex::new(HashMap::new()), + map: Arc::new(Mutex::new(HashMap::new())), + max_entries_len: DEFAULT_MAX_LEN, } } // Cleans expired DNS cache entries by retaining only valid ones fn clean_expired_entries(&self) { let mut map_lock = self.map.lock().unwrap(); - if map_lock.len() > MAX_ENTRIES_LEN { - map_lock.retain(|_, result| result.inner.lock().unwrap().is_valid()); + if map_lock.len() > self.max_entries_len { + map_lock.retain(|_, result| { + let inner = result.inner.lock().unwrap(); + inner.is_valid() + }); } } + + fn get_dns_manager() -> Arc> { + static DNS_MANAGER: OnceLock>> = OnceLock::new(); + DNS_MANAGER + .get_or_init(|| Arc::new(Mutex::new(DnsManager::new()))) + .clone() + } } struct DnsResult { @@ -174,7 +150,6 @@ impl DnsResult { } } -#[derive(Clone)] struct DnsResultInner { addr: Vec, // List of resolved addresses for the authority expiration_time: Instant, // Expiration time for the cache entry @@ -193,14 +168,52 @@ impl Default for DnsResultInner { fn default() -> Self { DnsResultInner { addr: vec![], - expiration_time: Instant::now() + Duration::from_secs(60), + expiration_time: Instant::now() + DEFAULT_TTL, } } } -struct DnsConnector {} +/// Default dns resolver used by the `Client`. +pub struct DefaultDnsResolver { + connector: DefaultDnsConnector, // Performing DNS resolution + ttl: Duration, // Time-to-live for the DNS cache +} -impl DnsConnector { +impl Default for DefaultDnsResolver { + // Default constructor for `DefaultDnsResolver`, with a default TTL of 60 + // seconds. + fn default() -> Self { + DefaultDnsResolver { + connector: DefaultDnsConnector {}, + ttl: DEFAULT_TTL, // Default TTL set to 60 seconds + } + } +} + +impl DefaultDnsResolver { + /// Create a new DefaultDnsResolver. And TTL is Time to live for cache. + /// + /// # Examples + /// + /// ``` + /// use std::time::Duration; + /// + /// use ylong_http_client::async_impl::DefaultDnsResolver; + /// + /// let res = DefaultDnsResolver::new(Duration::from_secs(1)); + /// ``` + pub fn new(ttl: Duration) -> Self { + DefaultDnsResolver { + connector: DefaultDnsConnector {}, + ttl, // Set TTL through the passed parameters + } + } +} + +#[derive(Clone)] +struct DefaultDnsConnector {} + +impl DefaultDnsConnector { // Resolves the authority to a list of socket addresses fn get_socket_addrs(&self, authority: &str) -> Result, io::Error> { authority @@ -213,222 +226,682 @@ impl DnsConnector { impl Resolver for DefaultDnsResolver { fn resolve(&self, authority: &str) -> SocketFuture { let authority = authority.to_string(); - self.manager.clean_expired_entries(); - Box::pin(async move { - let mut map_lock = self.manager.map.lock().unwrap(); - if let Some(addrs) = map_lock.get(&authority) { + let ttl = self.ttl; + let connector = self.connector.clone(); + let blocking = crate::runtime::spawn_blocking(move || { + let get_dns_manager = DnsManager::get_dns_manager(); + let map_lock = get_dns_manager.lock().unwrap(); + map_lock.clean_expired_entries(); + if let Some(addrs) = map_lock.map.lock().unwrap().get(&authority) { let lock_inner = addrs.inner.lock().unwrap(); if lock_inner.is_valid() { - return Ok(Box::new(lock_inner.addr.clone().into_iter()) as Addrs); + return Ok(ResolvedAddrs::new(lock_inner.addr.clone().into_iter())); } } - match self.connector.get_socket_addrs(&authority) { + match connector.get_socket_addrs(&authority) { Ok(addrs) => { - let dns_result = DnsResult::new(addrs.clone(), Instant::now() + self.ttl); - map_lock.insert(authority, dns_result); - Ok(Box::new(addrs.into_iter()) as Addrs) + let dns_result = DnsResult::new(addrs.clone(), Instant::now() + ttl); + map_lock.map.lock().unwrap().insert(authority, dns_result); + Ok(ResolvedAddrs::new(addrs.into_iter())) } - Err(err) => Err(Box::new(err) as StdError), + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)), } - }) + }); + + Box::pin(DefaultDnsFuture { inner: blocking }) + } +} + +#[derive(Default)] +pub struct DnsResolver { + connector: DnsConnector, +} + +impl DnsResolver { + pub fn add_dns_server(mut self, dns_server: &str) -> Self { + self.connector.dns_servers.push(dns_server.to_string()); + self + } +} + +#[derive(Clone)] +struct DnsConnector { + dns_servers: Vec, + max_retry_count: i32, +} + +impl Default for DnsConnector { + fn default() -> Self { + DnsConnector { + dns_servers: vec![DEFAULT_DNS_SERVER.to_string()], + max_retry_count: DEFAULT_MAX_RETRY_COUNT, + } + } +} + +impl DnsConnector { + fn retry(&self, authority: &str) -> Result<(Vec, u64), HttpClientError> { + for _i in 0..self.max_retry_count { + for server in &self.dns_servers { + if let Ok((socket_addr, ttl)) = self.get_socket_addrs(authority, server.clone()) { + return Ok((socket_addr, ttl)); + } + } + } + Err(HttpClientError::from_str( + ErrorKind::Connect, + "Can't find valid address", + )) + } + + fn get_socket_addrs( + &self, + authority: &str, + server: String, + ) -> Result<(Vec, u64), HttpClientError> { + let part: Vec<&str> = authority.split(':').collect(); + let host: &str = part[0]; + let port: u16 = part[1].parse().unwrap(); + println!("host: {:?}, port: {:?}", host, port); + let socket = UdpSocket::bind("0.0.0.0:0").unwrap(); + let mut socket_addrs = Vec::new(); + let mut ttl = u64::MAX; + + let query_message4 = self.build_dns_query(host, "IPV4"); + self.send_dns_query(&socket, &server, query_message4); + let response4 = self.receive_dns_response(&socket); + println!("{:?}", response4); + let (socket_addr, answer_ttl) = self + .parse_dns_response(&response4, host, port, "IPV4") + .unwrap(); + socket_addrs.extend(&socket_addr); + if answer_ttl < ttl { + ttl = answer_ttl; + } + + let query_message6 = self.build_dns_query(host, "IPV6"); + self.send_dns_query(&socket, &server, query_message6); + let response6 = self.receive_dns_response(&socket); + println!("{:?}", response6); + let (socket_addr, answer_ttl) = self + .parse_dns_response(&response6, host, port, "IPV6") + .unwrap(); + socket_addrs.extend(&socket_addr); + if answer_ttl < ttl { + ttl = answer_ttl; + } + Ok((socket_addrs, ttl)) + } + + fn send_dns_query(&self, socket: &UdpSocket, dns_server: &str, query_message: Vec) { + socket + .send_to(&query_message, dns_server) + .expect("Failed to send DNS query"); + } + + fn receive_dns_response(&self, socket: &UdpSocket) -> Vec { + let mut buffer = [0; 512]; + + let (_, _) = socket + .recv_from(&mut buffer) + .expect("Failed to receive DNS response"); + buffer.to_vec() + } + + fn parse_dns_response( + &self, + response: &[u8], + host: &str, + port: u16, + ip_type: &str, + ) -> Result<(Vec, u64), HttpClientError> { + let answer = &response[18 + host.len()..]; + let mut min_ttl = u64::MAX; + let mut socket_addrs = Vec::new(); + let mut index = 0; + while answer[index] != 0 { + let ttl: u64 = answer[index + 9] as u64 + + (answer[index + 8] as u64) * 256 + + (answer[index + 7] as u64) * 256 * 256 + + (answer[index + 6] as u64) * 256 * 256 * 256; + if ttl < min_ttl { + min_ttl = ttl; + } + if ip_type == "IPV4" { + let ip = Ipv4Addr::from([ + answer[index + 12], + answer[index + 13], + answer[index + 14], + answer[index + 15], + ]); + let socket_addr = SocketAddr::new(ip.into(), port); + socket_addrs.push(socket_addr); + index += 16; + } + if ip_type == "IPV6" { + let ip = Ipv6Addr::from([ + answer[index + 12] as u16 * 256 + answer[index + 13] as u16, + answer[index + 14] as u16 * 256 + answer[index + 15] as u16, + answer[index + 16] as u16 * 256 + answer[index + 17] as u16, + answer[index + 18] as u16 * 256 + answer[index + 19] as u16, + answer[index + 20] as u16 * 256 + answer[index + 21] as u16, + answer[index + 22] as u16 * 256 + answer[index + 23] as u16, + answer[index + 24] as u16 * 256 + answer[index + 25] as u16, + answer[index + 26] as u16 * 256 + answer[index + 27] as u16, + ]); + let socket_addr = SocketAddr::new(ip.into(), port); + socket_addrs.push(socket_addr); + index += 16 + 12; + } + } + Ok((socket_addrs, min_ttl)) + } + + fn build_dns_query(&self, domain_name: &str, ip_type: &str) -> Vec { + let mut message = vec![ + 0, 0, // 事务 ID + 1, 0, // 标准查询 + 0, 1, // 问题数 (同时查询 A 和 AAAA 记录) + 0, 0, // 回答数 + 0, 0, // 权威资源记录数 + 0, 0, // 额外资源记录数 */ + ]; + + for part in domain_name.split('.') { + message.push(part.len() as u8); // 每个部分的长度 + message.extend(part.bytes()); // 每个部分的内容 + } + message.push(0); // 结束字符串 + match ip_type { + "IPV4" => message.extend_from_slice(&[0, 1]), + "IPV6" => message.extend_from_slice(&[0, 28]), + _ => panic!(), + } + message.extend_from_slice(&[0, 1]); + message + } +} + +impl Resolver for DnsResolver { + fn resolve(&self, authority: &str) -> SocketFuture { + let authority = authority.to_string(); + let connector = self.connector.clone(); + let blocking = crate::runtime::spawn_blocking(move || { + let get_dns_manager = DnsManager::get_dns_manager(); + let map_lock = get_dns_manager.lock().unwrap(); + map_lock.clean_expired_entries(); + if let Some(addrs) = map_lock.map.lock().unwrap().get(&authority) { + let lock_inner = addrs.inner.lock().unwrap(); + if lock_inner.is_valid() { + return Ok(ResolvedAddrs::new(lock_inner.addr.clone().into_iter())); + } + } + match connector.retry(&authority) { + Ok((addrs, ttl)) => { + let dns_result = + DnsResult::new(addrs.clone(), Instant::now() + Duration::from_secs(ttl)); + map_lock.map.lock().unwrap().insert(authority, dns_result); + Ok(ResolvedAddrs::new(addrs.into_iter())) + } + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)), + } + }); + + Box::pin(DefaultDnsFuture { inner: blocking }) + } +} + +/// Doh resolver used by the `Client`. +#[derive(Default)] +pub struct DohResolver { + connector: DohConnector, // Performing DOH resolution +} + +impl DohResolver { + /// Create a default DohResolver. And set the doh server. + /// + /// # Examples + /// + /// ``` + /// use ylong_http_client::async_impl::DohResolver; + /// + /// let res = DohResolver::default().add_doh_server("https://1.12.12.12/dns-query"); + /// ``` + pub fn add_doh_server(mut self, doh_server: &str) -> Self { + self.connector.doh_servers.push(doh_server.to_string()); + self + } +} + +#[derive(Clone)] +struct DohConnector { + doh_servers: Vec, + max_retry_count: i32, +} + +impl Default for DohConnector { + fn default() -> Self { + DohConnector { + doh_servers: vec![DEFAULT_DOH_SERVER.to_string()], + max_retry_count: DEFAULT_MAX_RETRY_COUNT, + } + } +} + +impl DohConnector { + // Connects to the DOH server and retrieve DNS information. + async fn doh_connect( + &self, + authority: &str, + doh_server: String, + ) -> Result<(Vec, u64), HttpClientError> { + let part: Vec<&str> = authority.split(':').collect(); + let host: &str = part[0]; + let port: u16 = part[1].parse().unwrap(); + let url_4 = format!("{}?name={}&type=A", doh_server, host); + let url_6 = format!("{}?name={}&type=AAAA", doh_server, host); + let client_4 = Client::builder().build()?; + let client_6 = Client::builder().build()?; + let request_4 = Request::builder().url(&url_4).body(Body::empty())?; + let request_6 = Request::builder().url(&url_6).body(Body::empty())?; + let response_4 = client_4.request(request_4).await?; + let response_6 = client_6.request(request_6).await?; + let text_4 = response_4.text().await?; + let text_6 = response_6.text().await?; + let text = format!("{},{}", text_4, text_6); + Ok(get_info(&text, port)) + } + + async fn retry(&self, authority: &str) -> Result<(Vec, u64), HttpClientError> { + for _i in 0..self.max_retry_count { + for server in &self.doh_servers { + if let Ok((socket_addr, ttl)) = self.doh_connect(authority, server.clone()).await { + return Ok((socket_addr, ttl)); + } + } + } + Err(HttpClientError::from_str( + ErrorKind::Connect, + "Can't find valid address", + )) + } +} + +// Parses and extracts information from the DNS response text. +fn get_info(text: &str, port: u16) -> (Vec, u64) { + let mut ips = Vec::new(); + let mut start = 0; + let mut ttl = u64::MAX; + while let Some(answer_start) = find_answer(text, start) { + let (answer_end, answer_str) = get_answer_str(text, answer_start); + if let Some(ip) = extract_ip_from_answer(answer_str) { + if let Some(socket_addr) = create_socket_addr(ip, port) { + ips.push(socket_addr); + } + } + if let Some(answer_ttl) = extract_ttl_from_answer(answer_str) { + let answer_ttl = answer_ttl.parse().unwrap(); + if ttl > answer_ttl { + ttl = answer_ttl; + } + } + start = answer_end + 1; + } + (ips, ttl) +} + +fn find_answer(answer_section: &str, start: usize) -> Option { + answer_section[start..].find('{').map(|pos| start + pos) +} + +fn get_answer_str(answer_section: &str, answer_start: usize) -> (usize, &str) { + let answer_end = answer_section[answer_start..].find('}').unwrap() + answer_start; + (answer_end, &answer_section[answer_start..answer_end]) +} + +fn extract_ip_from_answer(answer_str: &str) -> Option { + if let Some(ip_pos) = answer_str.find("\"data\":\"") { + let ip_start = ip_pos + "\"data\":\"".len(); + if let Some(ip_end) = answer_str[ip_start..].find('\"') { + return Some(answer_str[ip_start..ip_start + ip_end].to_string()); + } + } + None +} + +fn extract_ttl_from_answer(answer_str: &str) -> Option { + if let Some(ttl_pos) = answer_str.find("\"TTL\":") { + let ttl_start = ttl_pos + "\"TTL\":".len(); + if let Some(ttl_end) = answer_str[ttl_start..].find(',') { + return Some(answer_str[ttl_start..ttl_start + ttl_end].to_string()); + } + } + None +} + +fn create_socket_addr(ip: String, port: u16) -> Option { + if let Ok(ipv4_addr) = Ipv4Addr::from_str(&ip) { + return Some(SocketAddr::new(IpAddr::V4(ipv4_addr), port)); + } + if let Ok(ipv6_addr) = Ipv6Addr::from_str(&ip) { + return Some(SocketAddr::new(IpAddr::V6(ipv6_addr), port)); + } + None +} + +impl Resolver for DohResolver { + fn resolve(&self, authority: &str) -> SocketFuture { + let authority = authority.to_string(); + let connector = self.connector.clone(); + let blocking = crate::runtime::spawn_blocking(move || { + let get_dns_manager = DnsManager::get_dns_manager(); + let map_lock = get_dns_manager.lock().unwrap(); + map_lock.clean_expired_entries(); + if let Some(addrs) = map_lock.map.lock().unwrap().get(&authority) { + let lock_inner = addrs.inner.lock().unwrap(); + if lock_inner.is_valid() { + return Ok(ResolvedAddrs::new(lock_inner.addr.clone().into_iter())); + } + } + #[cfg(feature = "ylong_base")] + let result = ylong_runtime::block_on(connector.retry(&authority)); + #[cfg(feature = "tokio_base")] + let result = tokio::runtime::Runtime::new() + .unwrap() + .block_on(connector.retry(&authority)); + match result { + Ok((addrs, ttl)) => { + let dns_result = + DnsResult::new(addrs.clone(), Instant::now() + Duration::from_secs(ttl)); + map_lock.map.lock().unwrap().insert(authority, dns_result); + Ok(ResolvedAddrs::new(addrs.into_iter())) + } + Err(err) => Err(io::Error::new(io::ErrorKind::Other, err)), + } + }); + Box::pin(DefaultDnsFuture { inner: blocking }) } } -#[cfg(feature = "tokio_base")] #[cfg(test)] -mod ut_dns_cache { - use std::sync::Arc; - use std::time::{Duration, Instant}; +mod ut_dns_manager { + use super::*; + + /// UT test case for `DnsManager::new` + /// + /// # Brief + /// 1. Creates a new `DnsManager` instance. + /// 2. Verifies the default `max_entries_len` is 30000. + /// 3. Sets and verifies a new `max_entries_len` of 1. + #[test] + fn ut_dns_manager_new() { + let manager = DnsManager::new(); + assert_eq!(manager.max_entries_len, 30000); + let mut map = manager.map.lock().unwrap(); + map.insert( + "example.com".to_string(), + DnsResult::new(vec![SocketAddr::from(([0, 0, 0, 1], 1))], Instant::now()), + ); + assert!(map.contains_key("example.com")); + } - use tokio::sync::Mutex; + /// UT test case for `DnsManager::clean_expired_entries` + /// + /// # Brief + /// 1. Creates a `DnsManager` instance and sets `max_entries_len` to 1. + /// 2. Adds two DNS results to the cache: one valid and one expired. + /// 3. Calls `clean_expired_entries` to remove expired entries. + /// 4. Verifies the expired entry is removed from the cache. + #[test] + fn ut_dns_manager_clean_cache() { + let mut manager = DnsManager::new(); + manager.max_entries_len = 1; + let mut map = manager.map.lock().unwrap(); + map.insert( + "example1.com".to_string(), + DnsResult::new( + vec![SocketAddr::from(([0, 0, 0, 1], 1))], + Instant::now() + Duration::from_secs(60), + ), + ); + map.insert( + "example2.com".to_string(), + DnsResult::new( + vec![SocketAddr::from(([0, 0, 0, 2], 2))], + Instant::now() - Duration::from_secs(60), + ), + ); + drop(map); + manager.clean_expired_entries(); + assert!(manager.map.lock().unwrap().contains_key("example1.com")); + assert!(!manager.map.lock().unwrap().contains_key("example2.com")); + } +} +#[cfg(test)] +mod ut_get_info { use super::*; - /// UT test cases for `DefaultDnsResolver::resolve`. + /// UT test case for `get_info` function with IPv4 address + /// + /// # Brief + /// 1. Provides a DNS response text for an IPv4 address. + /// 2. Calls `get_info` to extract addresses and TTL. + /// 3. Verifies the extracted address and TTL. + #[test] + fn ut_get_info_ipv4() { + let ipv4_text = + "{\"Status\":0,\"TC\":false,\"RD\":true,\"RA\":true,\"AD\":false,\"CD\":false,\ +\"Question\":[{\"name\":\"example.com.\",\"type\":1}],\"Answer\":[{\"name\":\"\ +example.com.\",\"type\":1,\"TTL\":3378,\"data\":\"93.184.215.14\"}]}"; + let (addrs, ttl) = get_info(ipv4_text, 0); + assert_eq!(addrs, vec![SocketAddr::from(([93, 184, 215, 14], 0))]); + assert_eq!(ttl, 3378); + } + + /// UT test case for `get_info` function with IPv6 address + /// + /// # Brief + /// 1. Provides a DNS response text for an IPv6 address. + /// 2. Calls `get_info` to extract addresses and TTL. + /// 3. Verifies the extracted address and TTL. + #[test] + fn ut_get_info_ipv6() { + let ipv6_text = + "{\"Status\":0,\"TC\":false,\"RD\":true,\"RA\":true,\"AD\":false,\"CD\":false,\ +\"Question\":[{\"name\":\"example.com.\",\"type\":28}]\"Answer\":[{\"name\":ex\ +ample.com.\",\"type\":28,\"TTL\":1466,\"data\":\"2606:2800:21f:cb07:6820:80da:\ +af6b:8b2c\"}]}"; + let (addrs, ttl) = get_info(ipv6_text, 0); + assert_eq!( + addrs, + vec![SocketAddr::from(( + [0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c], + 0 + ))] + ); + assert_eq!(ttl, 1466); + } + + /// UT test case for `get_info` function with both IPv4 and IPv6 addresses + /// + /// # Brief + /// 1. Provides a DNS response text with both IPv4 and IPv6 addresses. + /// 2. Calls `get_info` to extract the addresses and TTL. + /// 3. Verifies the extracted addresses and TTL. + #[test] + fn ut_get_info_both() { + let text = "{\"Status\":0,\"TC\":false,\"RD\":true,\"RA\":true,\"AD\":false,\"CD\":false,\ +\"Question\":[{\"name\":\"example.com.\",\"type\":1}],\"Answer\":[{\"name\":\"\ +example.com.\",\"type\":1,\"TTL\":3378,\"data\":\"93.184.215.14\"}]},{\"Status\ +\":0,\"TC\":false,\"RD\":true,\"RA\":true,\"AD\":false,\"CD\":false,\"Question\ +\":[{\"name\":\"example.com.\",\"type\":28}],\"Answer\":[{\"name\":\"example.c\ +om.\",\"type\":28,\"TTL\":1466,\"data\":\"2606:2800:21f:cb07:6820:80da:af6b:8b\ +2c\"}]}"; + let (addrs, ttl) = get_info(text, 0); + assert_eq!( + addrs, + vec![ + SocketAddr::from(([93, 184, 215, 14], 0)), + SocketAddr::from(( + [0x2606, 0x2800, 0x21f, 0xcb07, 0x6820, 0x80da, 0xaf6b, 0x8b2c], + 0 + )), + ] + ); + assert_eq!(ttl, 1466); + } + + /// UT test case for `get_info` function with some error response. /// /// # Brief - /// 1. Test that DNS resolution is correctly cached after the first - /// resolution. - /// 2. Ensure that a second resolution within TTL returns the same result - /// and is faster. - /// 3. Check that after TTL expiration, the resolver performs another - /// resolution and returns fresh results and is slower that second one. + /// 1. Provides a DNS response text with some error response. + /// 2. Calls `get_info` to extract the addresses and TTL. + /// 3. Verifies addresses is empty and TTL is the max of u64. + #[test] + fn ut_get_info_error() { + let error_text = "This is some error response."; + let (addrs, ttl) = get_info(error_text, 0); + assert_eq!(addrs, vec![]); + assert_eq!(ttl, u64::MAX); + } +} + +#[cfg(feature = "tokio_base")] +#[cfg(test)] +mod ut_resolver { + use std::time::Duration; + + use super::*; #[tokio::test] - async fn ut_dns_cache_test_async() { - let resolver = DefaultDnsResolver::new(Duration::from_millis(100)); - let domain = "example.com:0"; - let start = Instant::now(); - let addrs1 = resolver.resolve(domain).await; - let duration1 = start.elapsed(); + /// UT test case for `DefaultDnsResolver::resolve` + /// + /// # Brief + /// 1. Creates a default dns resolver with 50ms ttl. + /// 2. Calls resolve to get socket address twice. + /// 3. Verify second resolve are faster than first one. + /// 4. Verify second resolve result as same as first one. + async fn ut_defualt_dns_resolver_resolve() { + let authority = "example.com:0"; + let resolver = DefaultDnsResolver::new(Duration::from_millis(50)); + let addrs1 = resolver.resolve(authority).await; assert!(addrs1.is_ok()); - tokio::time::sleep(Duration::from_millis(80)).await; - let start = Instant::now(); - let addrs2 = resolver.resolve(domain).await; - let duration2 = start.elapsed(); + tokio::time::sleep(Duration::from_millis(10)).await; + let addrs2 = resolver.resolve(authority).await; assert!(addrs2.is_ok()); - if let (Ok(addrs1), Ok(addrs2)) = (addrs1, addrs2) { - let addrs1_vec: Vec = addrs1.collect(); - let addrs2_vec: Vec = addrs2.collect(); - assert_eq!(addrs1_vec, addrs2_vec); + if let (Ok(addr1), Ok(addr2)) = (addrs1, addrs2) { + let vec1: Vec = addr1.collect(); + let vec2: Vec = addr2.collect(); + assert_eq!(vec1, vec2); } - assert!(duration1 > duration2); - tokio::time::sleep(Duration::from_millis(80)).await; - let start = Instant::now(); - let _addrs3 = resolver.resolve(domain).await; - let duration3 = start.elapsed(); - assert!(duration3 > duration2); } - /// UT test cases for `DefaultDnsResolver::resolve` under multiple - /// concurrent requests. - /// - /// # Brief - /// 1. Test that multiple concurrent DNS resolution requests return the same - /// result. - /// 2. Ensure that only the first request performs the resolution and others - /// wait for the result. - /// 3. Check that subsequent requests are faster, as they use the cached - /// result. #[tokio::test] - async fn ut_dns_cache_test_multi_async() { - let domain = "example.com:0"; - let resolver = Arc::new(Mutex::new(DefaultDnsResolver::new(Duration::from_millis( - 500, - )))); - let first_duration = Arc::new(Mutex::new(None::)); - let first_addrs = Arc::new(Mutex::new(None::>)); - let mut handles = vec![]; - for _ in 0..3 { - let resolver = Arc::clone(&resolver); - let first_duration = Arc::clone(&first_duration); - let first_addrs = Arc::clone(&first_addrs); - let handle = tokio::spawn(async move { - let resolver = resolver.lock().await; - let start = Instant::now(); - let addrs = resolver.resolve(domain).await; - let duration = start.elapsed(); - assert!(addrs.is_ok()); - let addrs = addrs.unwrap().collect::>(); - let mut first_duration_locked = first_duration.lock().await; - if first_duration_locked.is_none() { - *first_duration_locked = Some(duration); - let mut first_addrs_locked = first_addrs.lock().await; - *first_addrs_locked = Some(addrs); - } else { - let first_duration_locked = first_duration_locked.as_ref().unwrap(); - assert!(*first_duration_locked > duration); - - let first_addrs_locked = first_addrs.lock().await; - assert!(*first_addrs_locked.as_ref().unwrap() == addrs); - } - }); - handles.push(handle); + async fn ut_dns_resolver_resolve() { + let authority = "example.com:0"; + let resolver = DnsResolver::default().add_dns_server("223.5.5.5:53"); + let addrs1 = resolver.resolve(authority).await; + assert!(addrs1.is_ok()); + tokio::time::sleep(Duration::from_millis(10)).await; + let addrs2 = resolver.resolve(authority).await; + assert!(addrs2.is_ok()); + if let (Ok(addr1), Ok(addr2)) = (addrs1, addrs2) { + let vec1: Vec = addr1.collect(); + let vec2: Vec = addr2.collect(); + assert_eq!(vec1, vec2); } - for handle in handles { - handle.await.unwrap(); + } + #[tokio::test] + /// UT test case for `DohResolver::resolve` + /// + /// # Brief + /// 1. Creates a doh resolver and set doh server. + /// 2. Calls resolve to get socket address twice. + /// 3. Verify second resolve are faster than first one. + /// 4. Verify second resolve result as same as first one. + async fn ut_doh_resolver_resolve() { + let authority = "example.com:0"; + let resolver = DohResolver::default().add_doh_server("https://1.12.12.12/dns-query"); + let addrs1 = resolver.resolve(authority).await; + assert!(addrs1.is_ok()); + tokio::time::sleep(Duration::from_millis(10)).await; + let addrs2 = resolver.resolve(authority).await; + assert!(addrs2.is_ok()); + if let (Ok(addr1), Ok(addr2)) = (addrs1, addrs2) { + let vec1: Vec = addr1.collect(); + let vec2: Vec = addr2.collect(); + assert_eq!(vec1, vec2); } } } #[cfg(feature = "ylong_base")] #[cfg(test)] -mod ut_dns_cache { - use std::sync::Arc; - use std::time::{Duration, Instant}; - - use ylong_runtime::sync::Mutex; +mod ut_resolver { + use std::time::Duration; use super::*; - - /// UT test cases for `DefaultDnsResolver::resolve`. + /// UT test case for `DefaultDnsResolver::resolve` /// /// # Brief - /// 1. Test that DNS resolution is correctly cached after the first - /// resolution. - /// 2. Ensure that a second resolution within TTL returns the same result - /// and is faster. - /// 3. Check that after TTL expiration, the resolver performs another - /// resolution and returns fresh results and is slower that second one. + /// 1. Creates a default dns resolver with 50ms ttl. + /// 2. Calls resolve to get socket address twice. + /// 3. Verify second resolve are faster than first one. + /// 4. Verify second resolve result as same as first one. #[test] - fn ut_dns_cache_test() { - ylong_runtime::block_on(ut_dns_cache_test_async()); + fn ut_dns_resolver_resolve() { + ylong_runtime::block_on(ut_dns_resolver_resolve_async()); } - async fn ut_dns_cache_test_async() { - let resolver = DefaultDnsResolver::new(Duration::from_millis(100)); - let domain = "example.com:0"; - let start = Instant::now(); - let addrs1 = resolver.resolve(domain).await; - let duration1 = start.elapsed(); + async fn ut_dns_resolver_resolve_async() { + let authority = "example.com:0"; + let resolver = DefaultDnsResolver::new(Duration::from_millis(50)); + let start1 = Instant::now(); + let addrs1 = resolver.resolve(authority).await; + let duration1 = start1.elapsed(); assert!(addrs1.is_ok()); - ylong_runtime::time::sleep(Duration::from_millis(80)).await; - let start = Instant::now(); - let addrs2 = resolver.resolve(domain).await; - let duration2 = start.elapsed(); + ylong_runtime::time::sleep(Duration::from_millis(10)).await; + let start2 = Instant::now(); + let addrs2 = resolver.resolve(authority).await; + let duration2 = start2.elapsed(); + assert!(duration1 > duration2); assert!(addrs2.is_ok()); - if let (Ok(addrs1), Ok(addrs2)) = (addrs1, addrs2) { - let addrs1_vec: Vec = addrs1.collect(); - let addrs2_vec: Vec = addrs2.collect(); - assert_eq!(addrs1_vec, addrs2_vec); + if let (Ok(addr1), Ok(addr2)) = (addrs1, addrs2) { + let vec1: Vec = addr1.collect(); + let vec2: Vec = addr2.collect(); + assert_eq!(vec1, vec2); } - assert!(duration1 > duration2); - ylong_runtime::time::sleep(Duration::from_millis(80)).await; - let start = Instant::now(); - let _addrs3 = resolver.resolve(domain).await; - let duration3 = start.elapsed(); - assert!(duration3 > duration2); } - /// UT test cases for `DefaultDnsResolver::resolve` under multiple - /// concurrent requests. + /// UT test case for `DohResolver::resolve` /// /// # Brief - /// 1. Test that multiple concurrent DNS resolution requests return the same - /// result. - /// 2. Ensure that only the first request performs the resolution and others - /// wait for the result. - /// 3. Check that subsequent requests are faster, as they use the cached - /// result. + /// 1. Creates a doh resolver and set doh server. + /// 2. Calls resolve to get socket address twice. + /// 3. Verify second resolve are faster than first one. + /// 4. Verify second resolve result as same as first one. #[test] - fn ut_dns_cache_test_multi() { - ylong_runtime::block_on(ut_dns_cache_test_multi_async()); - } - - async fn ut_dns_cache_test_multi_async() { - let domain = "example.com:0"; - let resolver = Arc::new(Mutex::new(DefaultDnsResolver::new(Duration::from_millis( - 500, - )))); - let first_duration = Arc::new(Mutex::new(None::)); - let first_addrs = Arc::new(Mutex::new(None::>)); - let mut handles = Vec::new(); - for _ in 0..3 { - let resolver = Arc::clone(&resolver); - let first_duration = Arc::clone(&first_duration); - let first_addrs = Arc::clone(&first_addrs); - let handle = ylong_runtime::spawn(async move { - let resolver = resolver.lock().await; - let start = Instant::now(); - let addrs = resolver.resolve(domain).await; - let duration = start.elapsed(); - assert!(addrs.is_ok()); - let addrs = addrs.unwrap().collect::>(); - let mut first_duration_locked = first_duration.lock().await; - if first_duration_locked.is_none() { - *first_duration_locked = Some(duration); - let mut first_addrs_locked = first_addrs.lock().await; - *first_addrs_locked = Some(addrs); - } else { - let first_duration_locked = first_duration_locked.as_ref().unwrap(); - assert!(*first_duration_locked > duration); - let first_addrs_locked = first_addrs.lock().await; - assert!(*first_addrs_locked.as_ref().unwrap() == addrs); - } - }); - handles.push(handle); - } - for handle in handles { - handle.await.unwrap(); + fn ut_doh_resolver_resolve() { + ylong_runtime::block_on(ut_doh_resolver_resolve_async()); + } + + async fn ut_doh_resolver_resolve_async() { + let authority = "example.com:0"; + let resolver = DohResolver::default().add_doh_server("https://1.12.12.12/dns-query"); + let start1 = Instant::now(); + let addrs1 = resolver.resolve(authority).await; + let duration1 = start1.elapsed(); + assert!(addrs1.is_ok()); + ylong_runtime::time::sleep(Duration::from_millis(10)).await; + let start2 = Instant::now(); + let addrs2 = resolver.resolve(authority).await; + let duration2 = start2.elapsed(); + assert!(duration1 > duration2); + assert!(addrs2.is_ok()); + if let (Ok(addr1), Ok(addr2)) = (addrs1, addrs2) { + let vec1: Vec = addr1.collect(); + let vec2: Vec = addr2.collect(); + assert_eq!(vec1, vec2); } } } diff --git a/ylong_http_client/src/async_impl/mod.rs b/ylong_http_client/src/async_impl/mod.rs index 572a636..d361f06 100644 --- a/ylong_http_client/src/async_impl/mod.rs +++ b/ylong_http_client/src/async_impl/mod.rs @@ -60,4 +60,6 @@ pub use ylong_http::body::{MultiPart, Part}; /// Client Adapter. pub type Client = client::Client; -pub use dns::{Addrs, DefaultDnsResolver, Resolver, SocketFuture, StdError}; +pub use dns::{ + Addrs, DefaultDnsResolver, DnsResolver, DohResolver, Resolver, SocketFuture, StdError, +}; diff --git a/ylong_http_client/src/lib.rs b/ylong_http_client/src/lib.rs index f4bbed2..d1ec7b3 100644 --- a/ylong_http_client/src/lib.rs +++ b/ylong_http_client/src/lib.rs @@ -85,13 +85,14 @@ pub(crate) mod runtime { io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, net::TcpStream, sync::{OwnedSemaphorePermit as SemaphorePermit, Semaphore}, - task::JoinHandle, + task::{spawn_blocking, JoinHandle}, time::{sleep, timeout, Sleep}, }; #[cfg(feature = "ylong_base")] pub(crate) use ylong_runtime::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}, net::TcpStream, + spawn_blocking, sync::Semaphore, task::JoinHandle, time::{sleep, timeout, Sleep}, -- Gitee