diff --git a/lib/core-net/client/connect3.c b/lib/core-net/client/connect3.c index 9d669e3b5a1c69a0884fe249475cd5870bce0fad..2a9fa6b83a1f11e6e110aa0297385f7493fb3a13 100644 --- a/lib/core-net/client/connect3.c +++ b/lib/core-net/client/connect3.c @@ -24,6 +24,108 @@ #include "private-lib-core.h" +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +size_t get_res_size(struct addrinfo *res) +{ + size_t size = 0; + for (struct addrinfo *p = res; p; p = p->ai_next) { + if (p->ai_family != AF_INET && p->ai_family != AF_INET6) { + continue; + } + ++size; + } + return size; +} + +struct addrinfo **get_dns_res(struct addrinfo *res, sa_family_t family, size_t totalLength, size_t *size) +{ + struct addrinfo **temp = (struct addrinfo **)(malloc(sizeof(struct addrinfo *) * totalLength)); + if (!temp) { + return NULL; + } + size_t index = 0; + for (struct addrinfo *p = res; p; p = p->ai_next) { + if (p->ai_family == family) { + temp[index] = p; + ++index; + } + } + *size = index; + return temp; +} + +sa_family_t change_family(sa_family_t nowFamily) +{ + if (nowFamily == AF_INET6) { + return AF_INET; + } + return AF_INET6; +} + +struct addrinfo *sort_dns(struct addrinfo *res) +{ + size_t totalLength = get_res_size(res); + if (totalLength == 0) { + return NULL; + } + + size_t ipv6Size = 0; + struct addrinfo **ipv6Dns = get_dns_res(res, AF_INET6, totalLength, &ipv6Size); + size_t ipv4Size = 0; + struct addrinfo **ipv4Dns = get_dns_res(res, AF_INET, totalLength, &ipv4Size); + if (ipv4Dns == NULL && ipv6Dns == NULL) { + return NULL; + } + + for (size_t i = 0; i < ipv6Size; ++i) { + ipv6Dns[i]->ai_next = NULL; + } + for (size_t i = 0; i < ipv4Size; ++i) { + ipv4Dns[i]->ai_next = NULL; + } + + size_t ipv6Index = 0; + size_t ipv4Index = 0; + sa_family_t now = AF_INET6; + + struct addrinfo *head = (struct addrinfo *)malloc(sizeof(struct addrinfo)); + memset(head, 0, sizeof(struct addrinfo)); + struct addrinfo *next = head; + + size_t minSize = MIN(ipv6Size, ipv4Size); + size_t index = 0; + while (index < 2 * minSize) { + if (now == AF_INET6) { + next->ai_next = ipv6Dns[ipv6Index++]; + } else { + next->ai_next = ipv4Dns[ipv4Index++]; + } + ++index; + now = change_family(now); + next = next->ai_next; + } + while (ipv6Index < ipv6Size) { + next->ai_next = ipv6Dns[ipv6Index++]; + ++index; + next = next->ai_next; + } + while (ipv4Index < ipv4Size) { + next->ai_next = ipv4Dns[ipv4Index++]; + ++index; + next = next->ai_next; + } + struct addrinfo *result = head->ai_next; + free(head); + if (ipv6Dns) { + free(ipv6Dns); + } + if (ipv4Dns) { + free(ipv4Dns); + } + return result; +} + void lws_client_conn_wait_timeout(lws_sorted_usec_list_t *sul) { @@ -183,6 +285,7 @@ lws_client_connect_3_connect(struct lws *wsi, const char *ads, lws_conmon_append_copy_new_dns_results(wsi, result); #endif + result = sort_dns((struct addrinfo *)result); lws_sort_dns(wsi, result); #if defined(LWS_WITH_SYS_ASYNC_DNS) lws_async_dns_freeaddrinfo(&result); @@ -244,7 +347,7 @@ lws_client_connect_3_connect(struct lws *wsi, const char *ads, * If the connection failed, the OS-level errno may be * something like EINPROGRESS rather than the actual problem * that prevented a connection. This value will represent the - * ¡°real¡± problem that we should report to the caller. + * "real" problem that we should report to the caller. */ int real_errno = 0;