diff --git a/src/knowledge_extractor/api_server.py b/src/knowledge_extractor/api_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1eaab9e568bd179ef6e67247ded689b07468f5c
--- /dev/null
+++ b/src/knowledge_extractor/api_server.py
@@ -0,0 +1,390 @@
+from openai import OpenAI
+import re
+import httpx
+import yaml
+
+#读取配置文件
+def load_llm_config(file_path):
+ with open(file_path, 'r', encoding='utf-8') as file:
+ config = yaml.safe_load(file)
+ return config.get('llm', {})
+
+llm_config = load_llm_config("llm_config.yaml")
+
+# 创建 OpenAI 客户端实例
+client = OpenAI(
+ base_url=llm_config.get('base_url', ''),
+ api_key=llm_config.get('api_key', ''),
+ http_client=httpx.Client(verify=False)
+)
+role_prompt = "你是一个专业的文本分析专家,擅长从复杂的技术文档中精准提取关键的应用参数信息,并以清晰、规范的方式呈现提取结果。"
+role_prompt2 = "你是一位资深的性能调优专家,拥有丰富的操作系统应用参数调优经验,并能给出有效的建议。"
+
+json_example = """
+[
+ {
+ "name": "innodb_write_io_threads",
+ "info": {
+ "desc": "The number of I/O threads for write operations in InnoDB. ",
+ "needrestart": "true",
+ "type": "continuous",
+ "min_value":2,
+ "max_value":200,
+ "default_value":4,
+ "dtype": "int",
+ "version":"8.0",
+ "related_param":[],
+ "options":null
+ }
+ },
+ {
+ "name": "innodb_read_io_threads",
+ "info": {
+ "desc": "MySQL [mysqld] parameters 'innodb_read_io_threads'.",
+ "needrestart": "true",
+ "type": "continuous",
+ "min_value": 1,
+ "max_value": 64,
+ "default_value": 4,
+ "dtype": "int",
+ "version":"8.0",
+ "related_param":[],
+ "options":null
+ }
+ }
+]
+"""
+
+def get_messages(
+ role_prompt: str,
+ history: str,
+ usr_prompt: str,
+) -> list:
+ """
+ 构建消息列表,用于与OpenAI模型进行对话。
+
+ Parameters:
+ role_prompt (str): 系统角色提示。
+ history (str): 历史对话内容。
+ usr_prompt (str): 用户当前请求。
+
+ Returns:
+ list: 包含系统角色提示、历史对话和当前请求的消息列表。
+ """
+ messages = []
+ if role_prompt != "":
+ messages.append({"role": "system", "content": role_prompt})
+ if len(history) > 0:
+ messages.append({"role": "assistant", "content":history})
+ if usr_prompt != "":
+ messages.append({"role": "user", "content": usr_prompt})
+ return messages
+
+
+
+def parameter_official_knowledge_preparation(text: str, app: str)-> str:
+ """
+ 从官方文档文本中提取参数信息并返回JSON格式的参数列表。
+
+ Parameters:
+ text (str): 官方文档文本内容。
+ example (str): 示例JSON格式。
+ app (str): 应用程序名称。
+
+ Returns:
+ str: 包含提取参数信息的JSON字符串,若提取失败则返回None。
+ """
+ prompt = '''
+ 你是一个专业的文本分析助手,擅长从技术文档中精准提取关键信息。现在,我给你一段关于{app}数据库参数配置相关的文本内容。你的任务是从这段文本中提取所有包含的{app}配置参数的信息,并按照指定格式输出。
+
+ <文本内容>
+ {text}
+
+ <任务要求>
+ 请将提取的信息以JSON格式返回,其中每个参数的信息应包含以下字段(如果文本中未提及某字段,请设置为null,不要自行生成信息):
+ name(参数名称)
+ desc(参数描述)
+ needrestart(设置参数后是否需要重启,布尔值)
+ type(参数是否连续,可以为continuous或discrete)
+ min_value(参数的最小值)
+ max_value(参数的最大值)
+ default_value(参数的默认值)
+ dtype(参数的数据类型,如int、string、boolean、float,该字段请只在给定的这几个中选择)
+ version(参数的生效版本)
+ related_param(与该参数存在关联的参数)
+ options(参数的离散值集合)
+
+ <注意事项>
+ 如果参数取值为连续值,请将options字段设置为null。
+ 如果参数取值为离散值(如ON/OFF),请将min_value字段和max_value字段均设置为null,将options设置为离散值集合(如["ON", "OFF"])。
+ 如果文本中未提及某个字段,请在JSON中将该字段设置为null。
+ 如果文本中未提及某个参数,请不要在JSON中输出该参数。
+ 最大值和最小值可以从“Permitted values”或“Range”等描述中获取。
+ needrestart字段可以参考Dynamic内容,Dynamic表示是否能动态调整该参数,该值为yes时needrestart值为false;该值为no时needrestart为true。
+ related_param字段可以在参数的描述中查找,若描述中提到其他的参数,则可以进一步判断是不是一个相关参数,如果是,请在该字段用列表输出。若没有,输出一个空列表。
+
+ <输出示例>
+ 请按照以下格式输出JSON数据:
+ {example}
+
+ '''
+ example = json_example
+ messages = get_messages(role_prompt,[],prompt.format(app=app, example=example, text=text))
+ chat_completion = client.chat.completions.create(
+ messages=messages,
+ model=llm_config.get('model', ''),
+ temperature=1
+ )
+
+ # 打印响应内容
+ print(chat_completion.choices[0].message.content)
+ ans = chat_completion.choices[0].message.content
+ ans = re.sub(r".*?", "", ans, flags=re.DOTALL)
+ json_pattern = r'\[.*\]'
+ json_str = re.search(json_pattern, ans, re.DOTALL)
+
+ if json_str:
+ # 提取匹配到的JSON字符串
+ json_str = json_str.group(0)
+ return json_str
+ else:
+ print("没有找到JSON数据")
+ return
+
+
+#从文本中提取参数信息
+def parameter_knowledge_preparation(text: str, params: list, app: str) -> str:
+ """
+ 从web等文本中提取给定参数列表中的参数信息并返回JSON格式的参数列表。
+
+ Parameters:
+ text (str): 文本内容。
+ example (str): 示例JSON格式。
+ params (list): 参数列表。
+ app (str): 应用程序名称。
+
+ Returns:
+ str: 包含提取参数信息的JSON字符串,若提取失败则返回None。
+ """
+ prompt = '''
+ 你是一个专业的文本分析助手,擅长从技术文档中精准提取关键信息。现在,我给你一段关于{app}参数配置相关的文本内容。你的任务是从这段文本中提取以下参数的信息,请将给定参数列表的参数信息尽可能详细地提取。
+
+ <文本内容>
+ {text}
+
+ <任务要求>
+ 注意,我只需要这些给定参数的信息,其他参数的信息请不要输出:
+ 给定的参数列表是:{params}
+ 请将提取的信息以 JSON 格式返回,其中每个参数的信息应包含在对应的键下。如果文本中没有提到的参数,请不要在 JSON 中将该参数输出。
+
+ 参考的执行步骤是:
+ 1. 首先匹配是否有给定参数列表中的{app}参数
+ 2. 将该参数的值或者描述等信息找到,需要的信息包括参数名称(name),参数描述(desc),参数设置后是否需要重启(needrestart),参数是否连续(type),参数最小值(min_value),参数最大值(max_value),参数默认值(default_value),参数的数据类型(dtype),参数的生效版本(version),与该参数存在关联的参数(related_param),参数的离散值集合(options)。注意:Dynamic表示是否能动态调整该参数,其为yes时needrestart为false。
+ 3. 注意如果参数取值为连续值,请将options字段设置为null。如果参数取值为离散值(如ON/OFF),请将min_value字段和max_value字段均设置为null,将options设置为离散值集合(如["ON", "OFF"])。
+ 4. 将找到的信息保存,未找到的信息项设置为null,请不要自己生成相关的信息,要在文本中查找。
+ 5. 将你从文本中获取到的信息以 json 格式输出,一个输出的示例为:
+ {example}
+ 其中没有获取到的信息请设置为null
+
+ 注意:只输出一个包括参数信息的 json。如果文本中没有提到的参数,请不要在 JSON 中输出。不在example中的信息项,请不要输出。
+
+'''
+ example = json_example
+ messages = get_messages(role_prompt,[],prompt.format(app=app, params=params,example=example, text=text))
+ chat_completion = client.chat.completions.create(
+ messages=messages,
+ model=llm_config.get('model', ''),
+ temperature=0.1,
+ )
+
+ # 打印响应内容
+ print(chat_completion.choices[0].message.content)
+ ans = chat_completion.choices[0].message.content
+ ans = re.sub(r".*?", "", ans, flags=re.DOTALL)
+
+ json_pattern = r'\[.*\]'
+ json_str = re.search(json_pattern, ans, re.DOTALL)
+
+ if json_str:
+ # 提取匹配到的JSON字符串
+ json_str = json_str.group(0)
+ return json_str
+ else:
+ print("没有找到JSON数据")
+ return
+
+
+#参数信息的gpt来源
+def get_gpt_knowledge(params: list, app: str) -> str:
+ """
+ GPTT获取给定参数列表中每个参数的详细信息,包括名称、描述、获取命令、设置命令等,并返回JSON格式的参数列表。
+
+ Parameters:
+ params (list): 参数列表。
+ example (str): 示例JSON格式。
+ app (str): 应用程序名称。
+
+ Returns:
+ str: 包含参数信息的JSON字符串,若获取失败则返回None。
+ """
+ prompt = '''
+ 请根据以下{app}参数列表,详细描述每个参数的相关知识,包括包括参数名称(name),参数描述(desc),参数获取命令(get),参数设置命令(set),参数设置后是否需要重启(needrestart),参数是否连续(type),参数最小值(min_value),参数最大值(max_value),参数默认值(default_value),参数的数据类型(dtype),参数的生效版本(version),参数的互相关联参数(related_param),参数的离散值集合(options)。
+ {params}
+
+ <任务要求>:
+ 1.对于每个参数,提供清晰的定义和作用描述。
+ 2.参数的描述(desc)要求至少包括:1)这个参数的详细作用;2)可以缓解系统的哪一方面瓶颈,从CPU,disk IO,network,memory中选择相关的瓶颈给出;3)取值的描述,如果取值只有几个,分别描述每个取值的意义。如果取值为范围,则说明增大或减小该值的意义和作用。
+ 3.参数的最小值(min_value)和最大值(max_value)及默认值(default_value),请直接给出数值,不要出现计算,数值不需要引号。若参数取值为连续值,则参数的离散值(options)字段设置为null;若参数取值为离散值,请将最大值最小值设置为null,将参数离散值(options)字段设置为离散取值,例如参数取值为ON/OFF,则将options设置为["ON","OFF"]
+ 4.参数的互相关联参数(related_param)是与该参数相互影响的参数,一般需要同时调整配置。该字段请用列表输出,若没有,则输出一个空列表。
+ 5.使用准确的技术术语,并确保描述的准确性和可靠性。
+ 6.输出格式将每个参数的所有信息总结为一个json格式的知识。
+ 最终将结果以json格式输出,输出的json格式不要包含注释,参数的描述用中文输出,描述中的瓶颈用英文表示,其余字段用英文输出,输出示例格式为:
+ {example}
+ '''
+
+ example = json_example
+ messages = get_messages(role_prompt2,[],prompt.format(app=app, params=params, example=example))
+ chat_completion = client.chat.completions.create(
+ messages=messages,
+ model= llm_config.get('model', ''),
+ temperature=0.1
+ )
+
+ # 打印响应内容
+ print(chat_completion.choices[0].message.content)
+ ans = chat_completion.choices[0].message.content
+ ans = re.sub(r".*?", "", ans, flags=re.DOTALL)
+
+ json_pattern = r'\[.*\]'
+ json_str = re.search(json_pattern, ans, re.DOTALL)
+
+ if json_str:
+ # 提取匹配到的JSON字符串
+ json_str = json_str.group(0)
+ print(json_str)
+ return json_str
+ else:
+ print("没有找到JSON数据")
+ return
+
+def aggregate_web_result(text: str, param: str, app: str) -> str:
+ """
+ 将多个JSON格式的参数描述整合成一个完整的JSON对象。
+
+ Parameters:
+ text (str): 多个JSON格式的参数描述文本。
+ param (str): 参数名称。
+ app (str): 应用程序名称。
+
+ Returns:
+ str: 整合后的JSON字符串,若整合失败则返回None。
+ """
+ prompt = '''
+ 我有一些JSON格式的{app}参数结构化信息,这些JSON对象描述了同一个参数{param}的不同属性。这些JSON对象可能包含重复或部分信息,需要将它们整合成一个完整的JSON对象。
+ 目标:
+ 将所有描述同一参数的JSON对象整合成一个完整的JSON对象,合并重复字段,并确保每个字段的值是准确且完整的。
+ 要求:
+ 请根据输入,生成一个整合后的JSON对象,确保字段值完整,但是请不要添加你的知识,只根据提供的json对象填充。
+ 输入:以下是输入的json格式的参数信息
+ {text}
+ '''
+ messages = get_messages(role_prompt2,[],prompt.format(app=app, param=param, text=text))
+ chat_completion = client.chat.completions.create(
+ messages=messages,
+ model= llm_config.get('model', ''),
+ temperature=0.1
+ )
+ print(chat_completion.choices[0].message.content)
+ ans = chat_completion.choices[0].message.content
+ ans = re.sub(r".*?", "", ans, flags=re.DOTALL)
+ json_pattern = r'\{.*\}'
+ json_str = re.search(json_pattern, ans, re.DOTALL)
+
+ if json_str:
+ # 提取匹配到的JSON字符串
+ json_str = json_str.group(0)
+ return json_str
+ else:
+ print("没有找到JSON数据")
+ return
+
+def aggregate_result(param: str, official: str, web: str, gpt: str, app: str) -> str:
+ """
+ 汇总来自官方文档、web网页和GPT的参数信息,并整合成一个完整的JSON对象。
+
+ Parameters:
+ param (str): 参数名称。
+ official (str): 官方文档中的参数信息。
+ web (str): web网页中的参数信息。
+ gpt (str): GPT中的参数信息。
+ app (str): 应用程序名称。
+
+ Returns:
+ str: 整合后的JSON字符串,若整合失败则返回None。
+ """
+ prompt = '''
+ 我有一些JSON格式的{app}参数结构化信息,这些JSON对象描述了同一个参数{param},JSON信息的来源分别为官方文档,web网页和GPT。
+ 具体信息如下:
+ 官方文档:{official}
+ web网页:{web}
+ GPT:{gpt}
+
+ 请根据以下要求和提示,汇总并处理来自 官方文档、web网页 和 GPT 的信息,并确保最终的描述准确、完整且一致。输出与输入结构相同的 JSON 格式参数信息:
+
+ 1. 参数描述(desc)
+ 请综合官方文档、web网页和ChatGPT的描述,提取清晰、详细的参数功能描述。若来源中的描述有所不同,请优先参考GPT和官方文档中提供的详细说明,并与web网页中的实践建议进行对比,确保描述完整、详细且准确。如果有冲突,选择权威的描述。如无冲突,尽可能保留更多详细内容。描述最后总结为中文。
+
+ 2. 是否需要重启(needrestart)
+ 根据官方文档、web网页和ChatGPT提供的信息,判断该应用参数修改后是否需要重启{app}服务才能生效。若来源中存在不同的意见,优先参考官方文档和GPT来源中的内容。如果官方文档和GPT中的做法冲突,请重新分析是否需要重启并给出结果。
+
+ 3. 参数类型(type)
+ 请根据官方文档、web网页和ChatGPT提供的描述,确认该参数的类型。类型描述只包括`continuous`、`discrete`,优先参考官方文档中的类型定义,并与web网页中的使用实例进行对比,确保类型选择准确。如果不确定,请重新分析参数是离散还是连续。
+
+ 4. 最小值(min_value)
+ 请根据官方文档、web网页和ChatGPT提供的信息,确认该参数的最小值。若参数是离散的,该值设置为null。若来源中有不同的最小值,请优先参考官方文档中的说明,或者通过查看{app}官方文档来确认。如果web网页和ChatGPT中的值不同,确保选择的最小值符合实际环境配置需求。
+
+ 5. 最大值(max_value)
+ 根据官方文档、web网页和ChatGPT提供的信息,确认该参数的最大值。若参数是离散的,该值设置为null。如果不同来源给出的最大值有所不同,选择它们的交集或更具权威性的最大值。例如,如果官方文档明确给出了最大值范围,而web网页和ChatGPT提供的最大值偏大或偏小,请参照官方文档中的推荐值。
+
+ 6. 默认值(default_value)
+ 请根据官方文档、web网页和ChatGPT提供的信息,确认该参数的默认值。若不同来源提供的默认值不一致,请优先参考官方文档中的默认值。如果有多个来源提供相同的默认值,则采用该值作为最终结论。
+
+ 7. 数据类型(dtype)
+ 根据官方文档、web网页和ChatGPT提供的信息,确认该参数的数据类型(如`int`、`string`、`boolean`等)。如果不同来源对数据类型的定义不一致,请优先参考官方文档中的准确描述,确保数据类型与{app}的实际配置一致。如果有多个合理的选项,请选择最常见的类型并验证其准确性。
+
+ 8. 生效版本(version)
+ 根据官方文档、web网页和ChatGPT提供的信息,确认该参数的生效版本。请优先参考官方文档中的生效版本。如果有多个来源提供相同的生效版本,则采用该值作为最终结论。
+
+ 9.相关参数(related_param)
+ 根据官方文档、web网页和ChatGPT提供的信息,确认该参数的相关参数。该字段请将多个来源的值取并集输出为列表。
+
+ 10.参数的离散值(options)
+ 根据官方文档、web网页和ChatGPT提供的信息,确认该参数的离散取值。若参数是连续的,该值设置为null。若参数是离散的,请优先参考官方文档来源的内容,如果有多个来源提供相同的离散值,则采用该离散值为最终结论。
+
+ 注意:最终输出结构和输入结构相同,输出一个json格式的参数信息,请不要给出分析过程,只输出最后的json。
+
+ '''
+ messages = get_messages(role_prompt2,[],prompt.format(app=app, param=param, official=official, web=web, gpt=gpt))
+ chat_completion = client.chat.completions.create(
+ messages=messages,
+ model= llm_config.get('model', ''),
+ temperature=0.1
+ )
+ #print(chat_completion.choices[0].message.content)
+ ans = chat_completion.choices[0].message.content
+ ans = re.sub(r".*?", "", ans, flags=re.DOTALL)
+ json_pattern = r'\{.*\}'
+ json_str = re.search(json_pattern, ans, re.DOTALL)
+
+ if json_str:
+ # 提取匹配到的JSON字符串
+ json_str = json_str.group(0)
+ return json_str
+ else:
+ print("没有找到JSON数据")
+ return
+
+
+if __name__ == "__main__":
+ parameter_knowledge_preparation()
+
diff --git a/src/knowledge_extractor/app_config.yaml b/src/knowledge_extractor/app_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4472b6f163e4ff71f737fc5e8417b872c2e9946f
--- /dev/null
+++ b/src/knowledge_extractor/app_config.yaml
@@ -0,0 +1,35 @@
+# ==============================================
+# 知识库构建配置文件
+# ==============================================
+
+# 应用程序信息
+app_name: mysql
+# 参数文件路径
+params_file: ../mysql_knowledge/params.txt
+# 保存路径
+save_path: ../mysql_knowledge/
+# 滑动窗口大小
+window_size: 5000
+# 官方文档 URL 列表
+official_url: ["https://mysql.net.cn/doc/refman/8.0/en/innodb-parameters.html"]
+# Web 网页 URL 列表
+web_url: ["https://www.cnblogs.com/kevingrace/p/6133818.html"]
+# 代码路径(用于源代码参数提取)
+code_path: ../code
+# 额外补充文件列表(支持 Excel、PDF 等格式)
+append_file_paths: ["sample1.xlsx", "sample2.xlsx", "score.xlsx"]
+
+# ==============================================
+# 步骤开关配置
+# ==============================================
+
+step1: true # 读取参数列表
+step2: true # 获取官方文档信息
+step3: true # 获取web网页信息
+step4: true # 官方文档参数提取
+step5: true # web信息参数提取
+step6: false # 源代码参数提取
+step7: true # GPT生成参数信息
+step8: false # 补充文件生成参数信息
+step9: true # 参数知识聚合
+step10: true # 参数汇总文件生成
\ No newline at end of file
diff --git a/src/knowledge_extractor/code_extractor.py b/src/knowledge_extractor/code_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b00c69fec1c2b3d65c65cb59b53841061e98945
--- /dev/null
+++ b/src/knowledge_extractor/code_extractor.py
@@ -0,0 +1,90 @@
+import os
+
+def extract_code_snippets(param_list: list, folder_path: str, output_file: str, context_lines: int = 20) -> None:
+ """
+ 提取指定文件夹中与给定参数列表相关的代码片段,并将结果写入到指定的输出文件中。
+
+ Parameters:
+ param_list (list): 参数列表,例如 ["spark.driver.memoryOverheadFactor", "spark.executor.memory"]
+ folder_path (str): 要查询的文件夹路径
+ output_file (str): 输出文件路径
+ context_lines (int, optional): 提取代码片段时包含的上下文行数。默认为20。
+
+ Returns:
+ None
+ """
+ results = {}
+
+ # 遍历文件夹中的所有文件
+ for root, _, files in os.walk(folder_path):
+ for file_name in files:
+ file_path = os.path.join(root, file_name)
+ try:
+ with open(file_path, 'r', encoding='utf-8') as file:
+ lines = file.readlines()
+ except Exception as e:
+ print(f"无法读取文件 {file_path}: {e}")
+ continue
+
+ # 检查文件内容
+ related_ranges = []
+ for i, line in enumerate(lines):
+ # 检查当前行是否包含参数
+ for param in param_list:
+ if param in line:
+ # 计算上下文范围
+ start = max(0, i - context_lines)
+ end = min(len(lines), i + context_lines + 1)
+ related_ranges.append((start, end))
+ break
+
+
+ # 合并重叠的范围
+ if related_ranges:
+ # 按起始行排序
+ print(related_ranges)
+ related_ranges.sort(key=lambda x: x[0])
+ merged_ranges = []
+ for current_start, current_end in related_ranges:
+ if not merged_ranges:
+ merged_ranges.append((current_start, current_end))
+ else:
+ last_start, last_end = merged_ranges[-1]
+ # 如果当前范围与最后一个合并范围有重叠或相邻,则合并
+ if current_start <= last_end:
+ merged_ranges[-1] = (last_start, max(last_end, current_end))
+ else:
+ merged_ranges.append((current_start, current_end))
+
+ #print(merged_ranges)
+ # 提取合并后的代码片段
+ related_lines = []
+ for start, end in merged_ranges:
+ related_lines.extend(lines[start:end])
+ related_lines.extend("-" * 50+ "\n" )
+
+ unique_lines = []
+ for line in related_lines:
+ unique_lines.append(line)
+ results[file_path] = unique_lines
+
+ # 将结果写入到输出文件中
+ with open(output_file, 'w', encoding='utf-8') as outfile:
+ for file_path, lines in results.items():
+ for line in lines:
+ outfile.write(line)
+
+
+# 示例用法
+if __name__ == "__main__":
+ # 示例参数列表
+ params = ["spark.dynamicAllocation.minExecutors","spark.dynamicAllocation.maxExecutors","spark.dynamicAllocation.initialExecutors","spark.shuffle.service.db.enabled"]
+
+ # 示例文件夹路径
+ folder = "../code"
+
+ # 输出文件路径
+ output = "./output.txt"
+
+ # 提取代码片段并写入到文件
+ extract_code_snippets(params, folder, output)
\ No newline at end of file
diff --git a/src/knowledge_extractor/document_loaders.py b/src/knowledge_extractor/document_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..600b649a02afaca26ae8c2feb60be5fc1389ad8c
--- /dev/null
+++ b/src/knowledge_extractor/document_loaders.py
@@ -0,0 +1,171 @@
+import pandas as pd
+import numpy as np
+import re
+import fitz # PyMuPDF
+from docx import Document
+
+# 判断是否为数字
+def is_number(s):
+ try:
+ float(s)
+ return True
+ except ValueError:
+ return False
+
+# 判断是否为英文
+def is_english(s):
+ return bool(re.match(r'^[a-zA-Z0-9_]+$', str(s)))
+
+# 判断是否为中文
+def is_chinese(s):
+ return bool(re.match(r'^[\u4e00-\u9fff]+$', str(s)))
+
+# 分析数据类型
+def analyze_data_types(df):
+ row_types = []
+ col_types = []
+
+ # 分析每行的数据类型
+ for i in range(len(df)):
+ row_type = set()
+ for value in df.iloc[i]:
+ if pd.isna(value):
+ continue
+ if is_number(value):
+ row_type.add('number')
+ elif is_english(value):
+ row_type.add('english')
+ elif is_chinese(value):
+ row_type.add('chinese')
+ else:
+ row_type.add('string')
+ row_types.append(row_type)
+
+ # 分析每列的数据类型
+ for j in range(len(df.columns)):
+ col_type = set()
+ for value in df.iloc[:, j]:
+ if pd.isna(value):
+ continue
+ if is_number(value):
+ col_type.add('number')
+ elif is_english(value):
+ col_type.add('english')
+ elif is_chinese(value):
+ col_type.add('chinese')
+ else:
+ col_type.add('string')
+ col_types.append(col_type)
+
+ return row_types, col_types
+
+# 计算行之间的相似度
+def calculate_row_similarity(row_types):
+ similarities = []
+ for i in range(len(row_types)):
+ row_similarities = []
+ for j in range(len(row_types)):
+ if i != j:
+ sim = len(row_types[i].intersection(row_types[j])) / len(row_types[i].union(row_types[j]))
+ row_similarities.append(sim)
+ else:
+ row_similarities.append(1) # 自身相似度为1
+ similarities.append(row_similarities)
+ return similarities
+
+# 计算列之间的相似度
+def calculate_column_similarity(col_types):
+ similarities = []
+ for i in range(len(col_types)):
+ col_similarities = []
+ for j in range(len(col_types)):
+ if i != j:
+ sim = len(col_types[i].intersection(col_types[j])) / len(col_types[i].union(col_types[j]))
+ col_similarities.append(sim)
+ else:
+ col_similarities.append(1) # 自身相似度为1
+ similarities.append(col_similarities)
+ return similarities
+
+# 判断数据排列方式并输出结构化数据到文本文件
+def excel2text(file_path, sheet_name='Sheet1', output_file='../mysql_knowledge/temp/excel_out.txt'):
+ # 读取Excel文件
+ df = pd.read_excel(file_path, sheet_name=sheet_name, header=None)
+
+ # 分析数据类型
+ row_types, col_types = analyze_data_types(df)
+
+ # 计算行之间的相似度
+ row_similarities = calculate_row_similarity(row_types)
+ row_avg_similarity = np.mean(row_similarities)
+
+ # 计算列之间的相似度
+ col_similarities = calculate_column_similarity(col_types)
+ col_avg_similarity = np.mean(col_similarities)
+
+ # 判断排列方式
+ if row_avg_similarity > col_avg_similarity:
+ orientation = "横向分布"
+ else:
+ orientation = "纵向分布"
+
+ # 打开输出文件
+ with open(output_file, 'w', encoding='utf-8') as f:
+ # 根据排列方式输出结构化数据
+ if orientation == "横向分布":
+ # 读取Excel文件,指定标题行
+ df = pd.read_excel(file_path, sheet_name=sheet_name, header=0)
+ for index, row in df.iterrows():
+ for column in df.columns:
+ f.write(f"{column}: {row[column]}\n")
+ f.write("-" * 40 + "\n")
+ else:
+ # 读取Excel文件,不指定标题行
+ df = pd.read_excel(file_path, sheet_name=sheet_name, header=None)
+ df_transposed = df.T
+ for index, row in df_transposed.iterrows():
+ for column in df_transposed.columns:
+ f.write(f"{row[0]}: {row[column]}\n")
+ f.write("-" * 40 + "\n")
+
+def pdf2text(pdf_path, output_file='../mysql_knowledge/temp/pdf_out.txt'):
+ # 打开PDF文件
+ doc = fitz.open(pdf_path)
+
+ # 打开输出文件
+ with open(output_file, 'w', encoding='utf-8') as f:
+ # 遍历PDF的每一页
+ for page_number in range(len(doc)):
+ page = doc.load_page(page_number) # 加载当前页
+ page_text = page.get_text() # 提取当前页的文本
+
+ # 将当前页的文本写入文件
+ f.write(f"Page {page_number + 1}:\n")
+ f.write(page_text)
+ f.write("\n\n") # 添加一个空行分隔各页的文本
+
+ # 关闭PDF文件
+ doc.close()
+
+
+def docx2text(docx_path, output_file='../mysql_knowledge/temp/docx_out.txt'):
+ """
+ 提取 .docx 文件中的文本内容并保存为 .txt 文件。
+
+ 参数:
+ docx_path (str): 输入的 .docx 文件路径。
+ output_txt_path (str): 输出的 .txt 文件路径。
+ """
+ try:
+ # 打开 .docx 文件
+ doc = Document(docx_path)
+
+ # 提取所有段落的文本
+ text = "\n".join([para.text for para in doc.paragraphs])
+
+ # 将文本写入 .txt 文件
+ with open(output_file, "w", encoding="utf-8") as txt_file:
+ txt_file.write(text)
+ except Exception as e:
+ print(f"Error processing the file: {e}")
+
diff --git a/src/knowledge_extractor/llm_config.yaml b/src/knowledge_extractor/llm_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52bb6e233bc89e3841727def1b3e66582885052e
--- /dev/null
+++ b/src/knowledge_extractor/llm_config.yaml
@@ -0,0 +1,4 @@
+llm:
+ model: gpt-4o-mini #'deepseek-r1:14b' 'qwen2:72b'
+ api_key:
+ base_url:
\ No newline at end of file
diff --git a/src/knowledge_extractor/main.py b/src/knowledge_extractor/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2ee080d5a316c2bb11c709c07ed5251042bc948
--- /dev/null
+++ b/src/knowledge_extractor/main.py
@@ -0,0 +1,356 @@
+from text_split import *
+from web_crawler import *
+from api_server import *
+from save_knowledge import *
+from code_extractor import *
+from document_loaders import *
+from merge_files import *
+from output import *
+import yaml
+import os
+import time
+
+def load_config(file_path: str) -> dict:
+ """
+ 加载配置文件。
+
+ Parameters:
+ file_path (str): 配置文件路径。
+
+ Returns:
+ dict: 配置信息。
+ """
+ with open(file_path, 'r', encoding='utf-8') as file:
+ config = yaml.safe_load(file)
+ return config
+
+
+def pipeline(param: str, app: str, config: dict) -> None:
+ """
+ 处理多来源参数知识并汇总生成知识库内容。
+
+ Parameters:
+ param (str): 参数名称。
+ app (str): 应用程序名称。
+ config (dict): 配置信息。
+
+ Returns:
+ None
+ """
+ save_path = config.get("save_path", "../{}_knowledge/".format(app)) # 存储文件的路径
+ output_dir = save_path + "summary/"
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ official_doc_path = save_path + "official/" + param + ".json"
+ gpt_suggestion_path = save_path + "gpt/" + param + ".json"
+ web_suggestion_path = save_path + "web/" + param + ".json"
+ official_doc = ""
+ gpt_suggestion = ""
+ web_suggestion = ""
+
+ try:
+ with open(official_doc_path, 'r', encoding='utf-8') as file:
+ official_doc = file.read()
+ print(official_doc)
+ except FileNotFoundError:
+ print("官方文档文件未找到,已设置为空。")
+
+ try:
+ with open(gpt_suggestion_path, 'r', encoding='utf-8') as file:
+ gpt_suggestion = file.read()
+ print(gpt_suggestion)
+ except FileNotFoundError:
+ print("GPT建议文件未找到,已设置为空。")
+
+ try:
+ with open(web_suggestion_path, 'r', encoding='utf-8') as file:
+ web_suggestion = file.read()
+ print(web_suggestion)
+ except FileNotFoundError:
+ print("WEB建议文件未找到,已设置为空。")
+
+ # 总结官方文档及gpt和web建议
+ sources_json = aggregate_result(param, official_doc, web_suggestion, gpt_suggestion, app)
+
+ try:
+ json_data = json.loads(sources_json)
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ with open(save_path + "summary/" + param + ".json", 'w', encoding='utf-8') as json_file:
+ json.dump(json_data, json_file, indent=4, ensure_ascii=False)
+
+
+def main(config_file_path: str) -> None:
+ """
+ 主函数,负责执行整个知识库生成流程。
+
+ Parameters:
+ config_file_path (str): 配置文件路径。
+
+ Returns:
+ None
+ """
+ config = load_config(config_file_path)
+ app = config.get("app_name", "mysql")
+ params_file = config.get("params_file", "")
+ save_path = config.get("save_path", "../{}_knowledge/".format(app))
+
+ start_time = time.perf_counter()
+
+ STEP = "====== {} ======"
+
+ # step 1: 参数列表读取
+ if config.get("step1", False):
+ print(STEP.format("步骤 1: 读取参数列表"))
+ if os.path.exists(params_file):
+ with open(params_file, 'r', encoding='utf-8') as f:
+ params = f.read()
+ params = [param.strip('"') for param in params.split()]
+ print(params)
+ time.sleep(2)
+ else:
+ print("提示:参数文件不存在,将提取官方文档中所有参数")
+ params = []
+ else:
+ print(STEP.format("步骤 1: 跳过"))
+
+ # step 2: official信息读取并存储文本
+ if config.get("step2", False):
+ print(STEP.format("步骤 2: official信息读取并存储"))
+ urls = config.get("official_url",[])
+ output_file = save_path + "official_text.txt"
+ for url in urls:
+ print(url)
+ convert_html_to_markdown(url,output_file)
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 2: 跳过"))
+
+ # step 3: web信息读取并存储文本
+ if config.get("step3", False):
+ print(STEP.format("步骤 3: web信息读取并存储"))
+ web_list = config.get("web_url",[])
+ for web in web_list:
+ output_file = save_path + "web_text.txt"
+ convert_html_to_markdown(web,output_file)
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 3: 跳过"))
+
+ # step 4: official信息转化为结构化数据(所有参数)
+ if config.get("step4", False):
+ print(STEP.format("步骤 4: official信息转化为结构化数据"))
+ # 进行滑窗划分和gpt提取知识,生成结构化数据
+ with open(save_path+"official_text.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ #找到最长的段落,将滑窗的滑动的步长设置为滑窗长度-longtext
+ long_text = find_longest_paragraph_length(text)
+ if long_text > 1000:
+ long_text = 1000
+ elif long_text < 300:
+ long_text = 300
+ window_size = config.get("window_size",5000)
+ step = window_size - long_text
+ # official信息的分割结果
+ segments = sliding_window_split(text[0:3000], window_size, step)
+ for segment in segments:
+ ans = parameter_official_knowledge_preparation(segment, app)
+ split_json_to_files(ans, save_path+"official")
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 4: 跳过"))
+ if len(params)==0:
+ for filename in os.listdir(save_path+"official"):
+ if filename.endswith(".json"):
+ params.append(filename[:-5])
+ print(params)
+
+ # step 5: web信息转化为分条的参数信息(根据给定参数列表)
+ if config.get("step5", False):
+ print(STEP.format("步骤 5: web信息转化为结构化数据"))
+ # web信息提取 信息已存在web_text.txt中
+ with open(save_path + "web_text.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ if not os.path.exists(save_path + "web/"):
+ os.makedirs(save_path + "web/")
+ #找到最长的段落,将滑窗的滑动的步长设置为滑窗长度-longtext
+ long_text = find_longest_paragraph_length(text)
+ if long_text > 1000:
+ long_text = 1000
+ elif long_text < 300:
+ long_text = 300
+ window_size = config.get("window_size",5000)
+ step = window_size - long_text
+ segments = sliding_window_split(text[0:3000], window_size, step)
+ ans_list = []
+ for segment in segments:
+ ans = parameter_knowledge_preparation(segment, params, app)
+ if ans is None:
+ continue
+ try:
+ ans_list.append(json.loads(ans))
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ for param in params:
+ web_text = ""
+ for ans in ans_list:
+ for item in ans:
+ if item["name"] == param:
+ web_text += str(item) + "\n"
+ if web_text != "":
+ ans = aggregate_web_result(web_text,param,app)
+ try:
+ json_data = json.loads(ans)
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ with open(save_path + "web/"+param+".json", 'w', encoding='utf-8') as json_file:
+ json.dump(json_data, json_file, indent=4, ensure_ascii=False)
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 5: 跳过"))
+
+ # step 6: 从源代码中获取参数知识
+ if config.get("step6", False):
+ print(STEP.format("步骤 6: 从源代码中获取参数知识"))
+ folder = config.get("code_path","../code")
+ out_file = save_path + "code_text.txt"
+ # 提取代码片段
+ extract_code_snippets(params, folder,out_file)
+ # code信息已存在code_text.txt中
+ with open(save_path + "code_text.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ if not os.path.exists(save_path + "code/"):
+ os.makedirs(save_path + "code/")
+ #找到最长的段落,将滑窗的滑动的步长设置为滑窗长度-longtext
+ long_text = find_longest_paragraph_length(text)
+ if long_text > 1000:
+ long_text = 1000
+ elif long_text < 300:
+ long_text = 300
+ window_size = config.get("window_size",5000)
+ step = window_size - long_text
+ segments = sliding_window_split(text, window_size, step)
+ ans_list = []
+ for segment in segments:
+ ans = parameter_knowledge_preparation(segment, params, app)
+ if ans is None:
+ continue
+ try:
+ ans_list.append(json.loads(ans))
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ for param in params:
+ web_text = ""
+ for ans in ans_list:
+ for item in ans:
+ if item["name"] == param:
+ web_text += str(item) + "\n"
+ if web_text != "":
+ ans = aggregate_web_result(web_text,param,app)
+ try:
+ json_data = json.loads(ans)
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ with open(save_path + "code/"+param+".json", 'w', encoding='utf-8') as json_file:
+ json.dump(json_data, json_file, indent=4, ensure_ascii=False)
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 6: 跳过"))
+
+ # step 7: GPT直接生成结构化数据
+ if config.get("step7", False):
+ print(STEP.format("步骤 7: GPT直接生成结构化数据"))
+ # gpt数据获取
+ batch_size = 15 # 每批次处理15个元素
+ for i in range(0, len(params), batch_size): # 按批次循环处理
+ batch_params = params[i:i+batch_size] # 提取当前批次
+ gpt_data = get_gpt_knowledge(batch_params,app)
+ split_json_to_files(gpt_data, save_path + "gpt")
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 7: 跳过"))
+
+ # step 8: 通过补充文件作为补充的参数信息输入。
+ if config.get("step8", False):
+ print(STEP.format("步骤 8: 通过补充文件作为补充的参数信息输入"))
+ append_file_paths = config.get("append_file_paths",[])
+ # 根据文件类型进行处理
+ for file_path in append_file_paths:
+ # 获取文件扩展名并转换为小写
+ _, file_extension = os.path.splitext(file_path)
+ file_extension = file_extension.lower()
+
+ # 判断文件类型
+ if file_extension == ".pdf":
+ pdf2text(file_path, save_path+"temp/pdf_out.txt")
+ with open(save_path+"temp/pdf_out.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ elif file_extension == ".docx":
+ docx2text(file_path, save_path+"temp/docx_out.txt")
+ with open(save_path+"temp/docx_out.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ elif file_extension in [".xlsx", ".xls"]:
+ excel2text(file_path, "Sheet1", save_path+"temp/excel_out.txt")
+ with open(save_path+"temp/excel_out.txt", 'r', encoding='utf-8') as f:
+ text = f.read()
+ else:
+ print("Unsupported File Type")
+ continue
+
+ if(len(text) < config.get("window_size",5000)):
+ ans = parameter_official_knowledge_preparation(segment)
+ split_json_to_files(ans, save_path+"addition")
+ else:
+ long_text = find_longest_paragraph_length(text)
+ if long_text > 1000:
+ long_text = 1000
+ window_size = config.get("window_size",5000)
+ step = window_size - long_text
+ # 文档补充信息的分割结果
+ segments = sliding_window_split(text, window_size, step)
+ for segment in segments:
+ ans = parameter_official_knowledge_preparation(segment)
+ split_json_to_files(ans, save_path+"addition")
+ time.sleep(2)
+ else:
+ print(STEP.format("步骤 8: 跳过"))
+
+ # step 9: 结构化数据的信息聚合
+ if config.get("step9", False):
+ print(STEP.format("步骤 9: 结构化数据的信息聚合"))
+ time.sleep(2)
+ for param in params:
+ pipeline(param, app, config)
+ else:
+ print(STEP.format("步骤 9: 跳过"))
+
+ # step 10: 生成参数知识库总文件
+ if config.get("step10", False):
+ print(STEP.format("步骤 10: 生成参数知识库总文件"))
+ time.sleep(2)
+ # # 合并json
+ input_directory = save_path+"summary" # 替换为包含JSON文件的目录路径
+ output_file = save_path+ app+ "_param.json" # 替换为输出文件的路径
+ merge_json_files(input_directory, output_file)
+ output_file_json = save_path+ app+ ".json" # 替换为输出文件的路径
+ process_json(output_file, output_file_json)
+
+ # # json转换为jsonl
+ # input_file = save_path+ app+ "_param.json"
+ # output_file = save_path+ app+ "_param.jsonl"
+ # json_to_jsonl(input_file, output_file)
+ else:
+ print(STEP.format("步骤 10: 跳过"))
+
+
+ end_time = time.perf_counter()
+ execution_time = end_time - start_time
+ print(f"生成的参数知识库内容:{params}")
+ print("-----------------------")
+ print(f"程序执行时间:{execution_time} 秒")
+ print(f"一共生成了参数知识:{len(params)} 条")
+
+if __name__ == "__main__":
+
+ config_file_path = "app_config.yaml"
+ main(config_file_path)
\ No newline at end of file
diff --git a/src/knowledge_extractor/merge_files.py b/src/knowledge_extractor/merge_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e9e10d4e48fa0018e6b5b0327c3e19e72bd6725
--- /dev/null
+++ b/src/knowledge_extractor/merge_files.py
@@ -0,0 +1,46 @@
+import json
+import os
+
+def merge_json_files(input_directory, output_file):
+ """
+ Merge multiple JSON files into a single JSON file.
+
+ :param input_directory: Directory containing JSON files to be merged.
+ :param output_file: Path to the output JSON file.
+ """
+ merged_data = []
+
+ # 遍历输入目录中的所有文件
+ for filename in os.listdir(input_directory):
+ if filename.endswith(".json"):
+ file_path = os.path.join(input_directory, filename)
+ try:
+ # 打开并读取JSON文件
+ with open(file_path, 'r', encoding='utf-8') as file:
+ data = json.load(file)
+ # 将内容添加到合并列表中
+ merged_data.append(data)
+ except json.JSONDecodeError as e:
+ print(f"Error reading {file_path}: {e}")
+ except Exception as e:
+ print(f"An error occurred: {e}")
+
+ # 将合并后的数据写入输出文件
+ with open(output_file, 'w', encoding='utf-8') as outfile:
+ json.dump(merged_data, outfile, indent=4, ensure_ascii=False)
+
+ print(f"Merged data has been written to {output_file}")
+
+def json_to_jsonl(input_file, output_file):
+ """
+ 将 JSON 文件转换为 JSON Lines 文件
+ """
+ # 读取 JSON 文件
+ with open(input_file, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+
+ # 将列表中的每个对象写入 JSON Lines 文件
+ with open(output_file, 'w', encoding='utf-8') as f:
+ for item in data:
+ json.dump(item, f, ensure_ascii=False)
+ f.write('\n') # 每个 JSON 对象占一行
\ No newline at end of file
diff --git a/src/knowledge_extractor/output.py b/src/knowledge_extractor/output.py
new file mode 100644
index 0000000000000000000000000000000000000000..0347b9de1478d8030f870f1871d57c499ec269a7
--- /dev/null
+++ b/src/knowledge_extractor/output.py
@@ -0,0 +1,58 @@
+import json
+
+def process_data(data):
+ processed_data = {}
+ for item in data:
+ param_name = item["name"]
+ param_info = item["info"]
+
+ processed_data[param_name] = {}
+ processed_data[param_name]["desc"] = param_info["desc"]
+ processed_data[param_name]["type"] = param_info["type"]
+ processed_data[param_name]["dtype"] = param_info["dtype"]
+ processed_data[param_name]["range"] = param_info["min_value"] if isinstance(param_info["min_value"], list) else [param_info["min_value"], param_info["max_value"]]
+ return processed_data
+
+def process_data1(data):
+ processed_data = {}
+ for item in data:
+ param_name = item["name"]
+ param_info = item["info"]
+
+ processed_data[param_name] = {}
+ processed_data[param_name]["desc"] = param_info["desc"]
+ processed_data[param_name]["type"] = param_info["type"]
+ processed_data[param_name]["dtype"] = param_info["dtype"]
+ processed_data[param_name]["range"] = param_info["options"] if param_info["type"] == "discrete" else [param_info["min_value"], param_info["max_value"]]
+ return processed_data
+
+def process_json(input_path, output_path):
+ """
+ 读取input_path指定的JSON文件,对数据进行处理,然后将处理后的数据保存到output_path指定的文件中。
+ """
+ try:
+ # 读取JSON文件
+ with open(input_path, 'r', encoding='utf-8') as infile:
+ data = json.load(infile)
+
+ processed_data = process_data1(data)
+
+ # 将处理后的数据保存到output_path
+ with open(output_path, 'w', encoding='utf-8') as outfile:
+ json.dump(processed_data, outfile, ensure_ascii=False, indent=4)
+
+ print(f"处理完成,结果已保存到 {output_path}")
+
+ except FileNotFoundError:
+ print(f"错误:文件 {input_path} 未找到。")
+ except json.JSONDecodeError:
+ print(f"错误:文件 {input_path} 不是有效的JSON格式。")
+ except Exception as e:
+ print(f"发生错误:{e}")
+
+
+# 示例调用
+if __name__ == "__main__":
+ input_path = "../mysql_knowledge/spark_param.json" # 输入文件路径
+ output_path = "../mysql_knowledge/spark.json" # 输出文件路径
+ process_json(input_path, output_path)
diff --git a/src/knowledge_extractor/requirements.txt b/src/knowledge_extractor/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2a69d10312b6df5a4810462a165dd0e5cc71ed22
--- /dev/null
+++ b/src/knowledge_extractor/requirements.txt
@@ -0,0 +1,11 @@
+PyMuPDF==1.25.2
+html2text==2024.2.26
+httpx==0.28.1
+jieba==0.42.1
+numpy==2.3.2
+openai==1.98.0
+pandas==2.3.1
+python-docx==1.1.2
+PyYAML==6.0.2
+readability_lxml==0.8.1
+requests==2.32.4
diff --git a/src/knowledge_extractor/save_knowledge.py b/src/knowledge_extractor/save_knowledge.py
new file mode 100644
index 0000000000000000000000000000000000000000..722dd0ca1ac066002901c4333e9ae5dccd4e51f3
--- /dev/null
+++ b/src/knowledge_extractor/save_knowledge.py
@@ -0,0 +1,53 @@
+import json
+import re
+import os
+
+def split_json_to_files(json_text: str, output_dir: str) -> None:
+ """
+ 将包含多个参数的JSON文本分割并存储到不同的JSON文件中。
+ 如果文件已存在,比较新旧内容,选择内容较多的那个进行保存。
+
+ Parameters:
+ json_text (str): 包含多个参数的JSON文本。
+ output_dir (str): 输出文件的目录。
+
+ Returns:
+ None
+ """
+ # 确保输出目录存在
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ # 解析JSON文本
+ if json_text is None:
+ return
+ try:
+ json_data = json.loads(json_text)
+ except json.JSONDecodeError as e:
+ print(f"JSON解析错误:{e}")
+ return
+
+ # 遍历JSON对象并为每个参数创建或更新JSON文件
+ for param in json_data:
+ # 创建文件名
+ if "/" in param['name'] or "\\" in param['name']:
+ continue
+ file_name = f"{param['name']}.json"
+ file_path = os.path.join(output_dir, file_name)
+
+ # 读取文件现有内容(如果存在)
+ existing_content = ""
+ if os.path.exists(file_path):
+ with open(file_path, 'r', encoding='utf-8') as file:
+ existing_content = file.read()
+
+ # 将参数转换为JSON字符串
+ new_content = json.dumps(param, ensure_ascii=False, indent=4)
+
+ # 比较新旧内容长度,选择较长的内容进行保存
+ if len(new_content) > len(existing_content):
+ with open(file_path, 'w', encoding='utf-8') as file:
+ file.write(new_content)
+ print(f"已将参数 {param['name']} 更新到文件 {file_path}")
+ else:
+ print(f"文件 {file_path} 已存在且内容较多,未更新")
\ No newline at end of file
diff --git a/src/knowledge_extractor/text_split.py b/src/knowledge_extractor/text_split.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e602989c6aad9c317dd4ba87b962ecf538b527e
--- /dev/null
+++ b/src/knowledge_extractor/text_split.py
@@ -0,0 +1,161 @@
+import jieba
+import re
+
+def is_chinese(text):
+ """
+ 判断文本是否为中文(如果包含非中文字符则返回False)
+
+ Parameters:
+ text (str): 输入文本
+
+ Returns:
+ bool: 如果文本全为中文字符,返回True;否则返回False
+ """
+ return all('\u4e00' <= char <= '\u9fff' for char in text)
+
+
+def split_text_into_segments(text: str, max_length: int = 5000) -> list:
+ """
+ 将文本分割成多个段落,每个段落不超过max_length字符,并尽量保留语义完整性。
+ 中文段落使用jieba分词,英文段落直接输出。
+
+ Parameters:
+ text (str): 输入文本
+ max_length (int): 每个段落的最大字符数,默认为5000
+
+ Returns:
+ list: 分割后的文本段落列表
+ """
+ paragraphs = text.split('\n') # 按行分段
+ segments = []
+ current_segment = ""
+
+ for para in paragraphs:
+ if len(current_segment) + len(para) + 1 <= max_length:
+ # 如果当前段落加上新段落不超过最大字符数,就加入当前段落
+ if current_segment:
+ current_segment += '\n' + para
+ else:
+ current_segment = para
+ else:
+ # 当前段落超过限制,保存并重新开始一个新的段落
+ segments.append(current_segment)
+ current_segment = para
+
+ # 添加最后一个段落
+ if current_segment:
+ segments.append(current_segment)
+
+ return segments
+
+
+def tokenize_text(text: str) -> str:
+ """
+ 对中文文本进行分词,英文文本保持原样
+
+ Parameters:
+ text (str): 输入文本
+
+ Returns:
+ str: 分词后的文本(中文)或原样文本(英文)
+ """
+ if is_chinese(text):
+ # 中文段落进行jieba分词
+ words = jieba.cut(text)
+ return " ".join(words)
+ else:
+ # 英文段落保持原样输出
+ return text
+
+
+def process_text_file(input_file: str, output_prefix: str, max_length: int = 5000):
+ """
+ 处理文本文件,进行分段和分词,并输出为多个文件。
+ 中文段落分词,英文段落保持原样输出
+
+ Parameters:
+ input_file (str): 输入文件路径
+ output_prefix (str): 输出文件前缀
+ max_length (int): 每个段落的最大字符数,默认为5000
+ """
+ # 读取输入文件
+ with open(input_file, 'r', encoding='utf-8') as f:
+ text = f.read()
+
+ # 将文本分割成多个段落
+ segments = split_text_into_segments(text, max_length)
+
+ # 输出分词后的段落到多个文件
+ for i, segment in enumerate(segments):
+ tokenized_text = tokenize_text(segment)
+ output_file = f"{output_prefix}_{i + 1}.txt"
+
+ with open(output_file, 'w', encoding='utf-8') as f:
+ f.write(tokenized_text)
+ print(f"文件 {output_file} 已创建,包含 {len(segment)} 字符。")
+
+
+def sliding_window_split(text: str, window_size: int, step: int) -> list:
+ """
+ 使用滑动窗口对文本进行分割。
+
+ Parameters:
+ text (str): 待分割的文本
+ window_size (int): 窗口大小,即每个分割片段的字符数
+ step (int): 滑动窗口的步长
+
+ Returns:
+ list: 分割后的文本片段列表
+
+ Raises:
+ ValueError: 如果窗口大小或步长不大于0,或者步长大于窗口大小
+ """
+ if window_size <= 0 or step <= 0:
+ raise ValueError("窗口大小和步长必须为正整数")
+
+ if window_size > len(text):
+ return [text] # 或者返回空列表 []
+
+ if step > window_size:
+ raise ValueError("步长不能大于窗口大小")
+
+ # 初始化一个空列表来存储分割后的文本片段
+ segments = []
+
+ # 使用滑动窗口进行文本分割
+ for i in range(0, len(text) - window_size + 1, step):
+ # 提取当前窗口内的文本片段
+ segment = text[i:i + window_size]
+ # 将文本片段添加到列表中
+ segments.append(segment)
+
+ # 检查是否还有剩余的文本
+ if len(text) % step != 0:
+ segments.append(text[-window_size:])
+
+ return segments
+
+def find_longest_paragraph_length(text: str, delimiter: str = '\n') -> int:
+ """
+ 找出文本中最长段落的长度
+
+ Parameters:
+ text (str): 输入文本
+ delimiter (str): 段落分隔符,默认为换行符
+
+ Returns:
+ int: 最长段落的长度
+ """
+ # 根据段落分隔符分割文本
+ if delimiter == '\n':
+ paragraphs = text.split(delimiter)
+ else:
+ paragraphs = re.split(delimiter, text)
+
+ # 计算每个段落的长度
+ paragraph_lengths = [len(paragraph) for paragraph in paragraphs]
+
+ # 找出最长段落的长度
+ max_paragraph_length = max(paragraph_lengths)
+
+ return max_paragraph_length
\ No newline at end of file
diff --git a/src/knowledge_extractor/web_crawler.py b/src/knowledge_extractor/web_crawler.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a3d14cacc2b4c9e6e15f7c6b266a86e57f85ac
--- /dev/null
+++ b/src/knowledge_extractor/web_crawler.py
@@ -0,0 +1,55 @@
+import requests
+import html2text
+from readability import Document
+
+
+def convert_html_to_markdown(url: str, output_file_path: str) -> None:
+ """
+ 将指定URL的HTML网页内容转换为Markdown,并保存到指定的文件中。
+
+ Parameters:
+ url (str): 目标网页的URL。
+ output_file_path (str): 保存Markdown内容的文件路径。
+
+ Returns:
+ None
+ """
+ headers = {
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0'
+ }
+ # 发送HTTP GET请求获取网页内容
+ response = requests.get(url, headers=headers, verify=False)
+ print(response.text)
+
+ # 检查请求是否成功
+ if response.status_code == 200:
+ # 获取HTML内容
+ html_content = response.text
+ doc = Document(html_content)
+ content = doc.summary(html_partial=False)
+
+ content = response.text
+ # 创建html2text对象
+ h = html2text.HTML2Text()
+
+ # 配置转换器(可选)
+ h.ignore_links = True # 是否忽略链接
+ h.ignore_images = True # 是否忽略图片
+ h.ignore_emphasis = True # 是否忽略强调(如斜体、粗体)
+
+ # 转换HTML为Markdown
+ markdown_content = h.handle(content)
+
+ # 打印Markdown内容
+ print(markdown_content)
+ # 将Markdown内容保存到文件中
+ with open(output_file_path, 'a', encoding='utf-8') as file:
+ file.write(markdown_content)
+
+ return markdown_content
+
+ else:
+ print(f"请求失败,状态码:{response.status_code}")
+ print("请检查网页链接的合法性,并适当重试。")
+
+