1 Star 0 Fork 1

饭饭/DNS缓存服务器

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
dns_local_server.py 10.79 KB
一键复制 编辑 原始数据 按行查看 历史
"""
-------------------------------------
# @Time : 2021/1/11 17:41
# @Author : ls
# @File : dns_local_server.py
# @IDE: PyCharm
--------------------------------------
"""
from utils.read_config import ConfigRead
from utils.package_tools import PackageTools
from dns_client import DNSClient
import logging
import time
from utils.threading_lock import ThreadingLock
import socket
import threading
from math import floor
from utils.cache_tools import CacheTools
class DNSLocalServer(object):
    """Local caching DNS server.

    Listens for UDP DNS queries on the address/port read from the
    ``SERVER_CONFIG`` section of the project configuration and dispatches each
    query by type:

    * ``PTR``  -- answered locally with ``PackageTools.PTR_reply()``.
    * ``AAAA`` -- answered locally with ``PackageTools.ipv6_reply()``
      (per the original comment: a "no such record" reply).
    * ``A``    -- answered from the cache when an entry exists (the entry is
      then refreshed from the public DNS server), otherwise resolved through
      the public DNS server and written to the cache.
    * anything else -- forwarded to the public DNS server unchanged.

    ``A`` and forwarded queries run on worker threads. Every worker parses its
    own copy of the packet; socket sends, cache access and the counters are
    serialised with ``lk``.
    """
    # Lock object from utils.threading_lock; only acquire()/release() are relied on here.
    lk = ThreadingLock.lk
    TOTAL_COUNT = 0  # queries answered (successfully or not)
    SUCCESS_COUNT = 0
    FAIL_COUNT = 0
    TOTAL_SAVE_TIME = 0  # seconds of upstream latency avoided by answering from cache
    RUNNING_TIME = time.time()  # start timestamp, taken when the class is defined
    CURRENT_CACHE_NUMBER = 0  # not updated by this class
    ct = CacheTools()

    def __init__(self):
        self.logger = logging.getLogger('test')
        self.conf = ConfigRead()
        # Local DNS server listen address and port.
        self.LOCAL_BIND_ADDRESS = self.conf.read_option_value('SERVER_CONFIG', 'LOCAL_DNS_SERVER_BIND_ADDRESS')
        self.LOCAL_BIND_PORT = int(self.conf.read_option_value('SERVER_CONFIG', 'LOCAL_DNS_SERVER_BIND_PORT'))
        self.SERVER_INFO = (self.LOCAL_BIND_ADDRESS, self.LOCAL_BIND_PORT)
        self.BUFFER_SIZE = 1024

    def dns_server_start(self):
        """Create and bind the UDP server socket.

        :return: the bound socket on success, ``False`` on failure (the error
                 is logged).
        """
        self.lk.acquire()
        try:
            self.server = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                self.server.bind(self.SERVER_INFO)
            except Exception:
                # Do not leak the file descriptor when the address cannot be bound.
                self.server.close()
                raise
            self.logger.info('\nDNS SERVER START SUCCESS\tIP={}\tPORT={}\n{}'
                             .format(self.LOCAL_BIND_ADDRESS, self.LOCAL_BIND_PORT, '=' * 100))
            self.ct = CacheTools()
        except Exception as e:
            self.logger.error(e)
            return False
        else:
            return self.server
        finally:
            self.lk.release()

    def _start_worker(self, target, args):
        """Run ``target(*args)`` on a new thread, logging any exception it raises."""
        def run():
            try:
                target(*args)
            except Exception as e:
                self.logger.error(e)
        threading.Thread(target=run).start()

    def capture_dns_pkt(self):
        """Receive DNS queries forever and dispatch them by query type.

        Worker threads receive the raw packet and build their own parser from
        it, so a later packet can never overwrite state they are still using.
        """
        while True:
            try:
                pkg, client_addr = self.server.recvfrom(self.BUFFER_SIZE)
                pt = PackageTools()
                pt.package = pkg
                domain_name = pt.domain_name()
                transaction_id = pt.transaction_id()
                query_type = pt.query_type()
                # Kept as instance attributes for backward compatibility only; worker
                # threads must not read them because the next packet replaces them.
                self.pt = pt
                self.domain_name = domain_name
                self.id = transaction_id
                self.transaction_id = transaction_id
                if query_type == 'PTR':
                    self.reply_PTR(client_addr, query_type)
                elif query_type == 'AAAA':
                    self.reply_ipv6_query(domain_name, client_addr, query_type)
                elif query_type == 'A':
                    # Workers write the cache under lk, so read it under lk as well.
                    self.lk.acquire()
                    try:
                        cached = self.ct.has_cache(domain_name)
                    finally:
                        self.lk.release()
                    if cached:
                        # Reply from cache immediately, then refresh the entry.
                        target = self.query_answer_from_cache
                    else:
                        # Resolve via the public DNS server and create the cache entry.
                        target = self.query_answer_from_server
                    self._start_worker(target, (domain_name, client_addr, pkg, query_type))
                else:
                    # Not A/AAAA/PTR: forward unchanged. Done on a worker so a slow
                    # upstream does not stall the receive loop.
                    self._start_worker(self.query_forward, (domain_name, client_addr, pkg, query_type))
            except Exception as e:
                self.logger.error(e)
                time.sleep(0.001)

    def generate_response_pkg(self, section, transaction_id=None, pt=None):
        """Build a response packet for *section* from the cached response data.

        :param section: cache section name (the queried domain name).
        :param transaction_id: transaction id placed in front of the cached
            data; defaults to ``self.transaction_id`` (the most recently
            received query -- pass it explicitly from worker threads).
        :param pt: PackageTools instance used for packing; defaults to ``self.pt``.
        :return: the result of ``pt.pack_package()``.
        """
        if transaction_id is None:
            transaction_id = self.transaction_id
        if pt is None:
            pt = self.pt
        response_data_cache = self.ct.read_cache(section, 'response_data')
        # The cached value looks like a tuple literal, e.g. "(1, 2, 3)": drop the
        # parentheses and turn the comma-separated numbers into a list of ints.
        response_data_tmp = response_data_cache.replace('(', '').replace(')', '')
        rsp_data_list = [int(v) for v in response_data_tmp.split(',')]
        pt.package = transaction_id + rsp_data_list
        return pt.pack_package()

    def query_answer_from_server(self, section, client_info, pkg, query_type):
        """Resolve *pkg* via the public DNS server, reply, and cache the answer.

        :param section: queried domain name (used as the cache section).
        :param client_info: client ``(address, port)`` to reply to.
        :param pkg: raw query packet received from the client.
        :param query_type: query type string, used for logging.
        """
        client = DNSClient()
        time_init = time.time()
        result = client.query_dns_from_remote_dns_server(pkg)
        ret = result[0]
        response_data = result[1]
        time_diff = round((time.time() - time_init) * 1000)
        pt = PackageTools()  # per-request parser; self.pt belongs to the receive loop
        self.lk.acquire()
        try:
            self.TOTAL_COUNT += 1
            if ret:
                self.server.sendto(response_data, client_info)
                pt.package = response_data
                if self.ct.has_cache(section):
                    # Another worker cached this name in the meantime.
                    self.ct.update_query_times(section)
                    ext_msg = None
                else:
                    self.ct.add_cache(section, pt.response_data())
                    ext_msg = {'info': 'Add cache success'}
                self.ct.write_cache()
                self.SUCCESS_COUNT += 1
            else:
                ext_msg = {'error': response_data, 'debug': 'Cache will not be written'}
                self.FAIL_COUNT += 1
            self.summary_print(section, client_info[0], query_type, 'server', time_diff, ext_msg)
        finally:
            self.lk.release()

    def query_answer_from_cache(self, section, client_info, pkg, query_type):
        """Reply from the cache, then re-query the public DNS server to refresh it.

        :param section: queried domain name (the cache section).
        :param client_info: client ``(address, port)`` to reply to.
        :param pkg: raw query packet received from the client.
        :param query_type: query type string, used for logging.
        """
        pt = PackageTools()  # per-request parser; self.pt belongs to the receive loop
        pt.package = pkg
        transaction_id = pt.transaction_id()
        self.lk.acquire()
        try:
            response_data_cache = self.generate_response_pkg(section, transaction_id, pt)
            self.server.sendto(response_data_cache, client_info)  # answer the client from cache
            self.ct.update_query_times(section)
        finally:
            self.lk.release()
        client = DNSClient()
        time_init = time.time()
        result = client.query_dns_from_remote_dns_server(pkg)
        ret = result[0]
        response_data = result[1]
        time_diff = round((time.time() - time_init) * 1000)
        self.lk.acquire()
        try:
            self.TOTAL_COUNT += 1
            # The client already got its answer, so the whole upstream round trip
            # (milliseconds) was saved regardless of how the refresh turned out.
            self.TOTAL_SAVE_TIME = round((self.TOTAL_SAVE_TIME * 1000 + time_diff) / 1000, 3)
            if ret:
                pt.package = response_data_cache
                old_response_data = pt.response_data()
                pt.package = response_data
                new_response_data = pt.response_data()
                if self.ct.update_response(section, old_rsp_data=old_response_data, new_rsp_data=new_response_data):
                    ext_msg = {'info': 'response_data updated'}
                else:
                    ext_msg = None
                self.SUCCESS_COUNT += 1
            else:
                ext_msg = {'error': response_data}
                self.FAIL_COUNT += 1
            self.summary_print(section, client_info[0], query_type, 'cache', time_diff, ext_msg)
            self.ct.write_cache()
        finally:
            self.lk.release()

    def reply_ipv6_query(self, section, client_info, query_type):
        """Answer an AAAA query locally with ``self.pt.ipv6_reply()``.

        Per the original comment this replies "no such record". Must be called
        from the receive loop, right after ``self.pt`` was set for this packet.
        """
        self.lk.acquire()
        try:
            self.TOTAL_COUNT += 1
            self.server.sendto(self.pt.ipv6_reply(), client_info)
            self.SUCCESS_COUNT += 1
            self.summary_print(section, client_info[0], query_type)
        finally:
            self.lk.release()

    def reply_PTR(self, client_info, query_type):
        """Answer a PTR query locally with ``self.pt.PTR_reply()``.

        Must be called from the receive loop, right after ``self.pt`` was set
        for this packet.
        """
        self.lk.acquire()
        try:
            self.TOTAL_COUNT += 1
            self.server.sendto(self.pt.PTR_reply(), client_info)
            self.SUCCESS_COUNT += 1
            self.summary_print('PTR', client_info[0], query_type)
        finally:
            self.lk.release()

    def query_forward(self, section, client_info, pkg, query_type):
        """Forward *pkg* to the public DNS server and relay the answer; nothing is cached."""
        client = DNSClient()
        time_init = time.time()
        result = client.query_dns_from_remote_dns_server(pkg)
        ret = result[0]
        response_data = result[1]
        time_diff = round((time.time() - time_init) * 1000)
        self.lk.acquire()
        try:
            self.TOTAL_COUNT += 1
            if ret:
                self.server.sendto(response_data, client_info)
                ext_msg = None
                self.SUCCESS_COUNT += 1
            else:
                ext_msg = {'error': response_data}
                self.FAIL_COUNT += 1
            self.summary_print(section, client_info[0], query_type, 'forward', time_diff, ext_msg)
        finally:
            self.lk.release()

    def summary_print(self, section, client_info, query_type, reply_source='', time_diff=0, ext_msg=None):
        """Log a summary of one DNS query plus the running statistics.

        :param section: domain name (or 'PTR').
        :param client_info: client address.
        :param query_type: query type string.
        :param reply_source: '', 'server', 'cache' or 'forward'.
        :param time_diff: upstream response time in milliseconds.
        :param ext_msg: optional ``{level: message}`` dict, where level is one of
            'debug', 'info', 'warn', 'error', 'critical'.
        """
        summary_info = 'DOMAIN_NAME={}\tCLIENT_ADDRESS={}\tTYPE={}'.format(section, client_info, query_type)
        summary_info = '{:^{}}'.format(summary_info, 100)
        self.logger.info('{}\n{}\n{}'.format('+' * 68, '=' * 100, summary_info))
        source_messages = {
            'server': 'Query from public DNS server',
            'cache': 'Reply with cache',
            'forward': 'Forward to public DNS server',
        }
        if reply_source in source_messages:
            self.logger.info(source_messages[reply_source])
        elif reply_source != '':
            self.logger.warning('error [reply source] parameter')
        if ext_msg is not None:
            level_loggers = {
                'debug': self.logger.debug,
                'info': self.logger.info,
                'warn': self.logger.warning,
                'error': self.logger.error,
                'critical': self.logger.critical,
            }
            for level, message in ext_msg.items():
                log = level_loggers.get(level)
                if log is None:
                    # format() instead of '+': message may be an error object, not a str.
                    self.logger.warning('unknown log level {}: {}'.format(level, message))
                else:
                    log(message)
        # NOTE(review): presumably reloads the cache file so the section count is current.
        self.ct.read_cache_file()
        runtime_and_cacheNumber = 'Running_Time:{}\tCache_Numbers:{}'.format(self.running_time(), len(self.ct.sections()))
        output_msg1 = 'Response_time={}ms\tTotal_save={}s'.format(time_diff, self.TOTAL_SAVE_TIME)
        output_msg1 = '{:^{}}'.format(output_msg1, 100)
        output_msg2 = 'TOTAL_COUNT={}\tSUCCESS_COUNT={}\tFAIL_COUNT={}'.format(self.TOTAL_COUNT, self.SUCCESS_COUNT, self.FAIL_COUNT)
        output_msg2 = '{:^{}}'.format(output_msg2, 100)
        self.logger.info('{}\n{}\n{}\n{}'.format(runtime_and_cacheNumber, output_msg1, output_msg2, '=' * 100))

    def running_time(self):
        """Return the time elapsed since ``RUNNING_TIME`` as e.g. '1d2h3m4s'.

        Larger units are omitted while they are zero ('45s', '3m5s', '2h0m7s').
        """
        total = round(time.time() - self.RUNNING_TIME)
        if total < 60:
            return '{}s'.format(total)
        minutes, seconds = divmod(total, 60)
        hours, minutes = divmod(minutes, 60)
        days, hours = divmod(hours, 24)
        if total < 3600:
            return '{}m{}s'.format(minutes, seconds)
        if total < 86400:
            return '{}h{}m{}s'.format(hours, minutes, seconds)
        return '{}d{}h{}m{}s'.format(days, hours, minutes, seconds)
if __name__ == '__main__':
    dls = DNSLocalServer()
    # dns_server_start() logs the error and returns False when the socket cannot
    # be created/bound; without a socket the receive loop would only log the same
    # error forever, so exit with a failure status instead.
    if dls.dns_server_start() is False:
        raise SystemExit(1)
    dls.capture_dns_pkt()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/nicol/dns-cache-server.git
git@gitee.com:nicol/dns-cache-server.git
nicol
dns-cache-server
DNS缓存服务器
master

搜索帮助