From bb0eeeafbedc0f204abecd76f5ae5ef4b978fb99 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Tue, 26 Nov 2024 21:03:46 +0800 Subject: [PATCH 1/3] basically finish doc -> graph Signed-off-by: shanhaikang.shk --- llm_extractor.py | 44 ++++++ load_docs_for_graph_rag.py | 314 +++++++++++++++++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 llm_extractor.py create mode 100644 load_docs_for_graph_rag.py diff --git a/llm_extractor.py b/llm_extractor.py new file mode 100644 index 0000000..03cf93f --- /dev/null +++ b/llm_extractor.py @@ -0,0 +1,44 @@ +import os +import dspy +from typing import List +from pydantic import BaseModel, Field + +class Entity(BaseModel): + name: str = Field( + description="从文本中抽取出的实体名字,以便构建知识图谱" + ) + +class Relationship(BaseModel): + source_entity: str = Field( + description="关系的源实体名字,要求必须是出现在实体列表当中的实体名" + ) + target_entity: str = Field( + description="关系的目标实体名字,要求必须是出现在实体列表当中的实体名" + ) + +class KnowledgeGraph(BaseModel): + entities: List[Entity] = Field( + description="知识图谱当中的一组实体,要求实体名不重复" + ) + relationships: List[Relationship] = Field( + description="知识图谱当中的一组关系,要求关系名不重复" + ) + +class ExtractKG(dspy.Signature): + text: str = dspy.InputField( + desc="基于这段文本抽取实体和关系来形成一个知识图谱" + ) + knowledge_graph: KnowledgeGraph = dspy.OutputField( + desc="基于文本抽取得到的知识图谱" + ) + +# tongyi_lm = dspy.LM( +# model="openai/qwen-plus", +# api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", +# api_key=os.environ.get("DASHSCOPE_API_KEY") +# ) +# dspy.settings.configure(lm=tongyi_lm) + +# kg_extractor = dspy.Predict(ExtractKG) +# pred = kg_extractor(text="OceanBase V4.3.1 版本在 V4.3.0 的基础上带来了新功能与性能优化。本次升级引入了全文索引特性,提高了文档检索效率。新增的功能还包括实时物化视图、物化视图改写及主键物化视图等,进而满足了多场景的数据分析需求。此外,引入的分区管理新机制,如分区交换、外表分区及 MySQL 模式下分区键数据类型的扩展,显著提升了处理大规模数据集的能力。本次版本更新同样针对多模特性(包括 JSON、XML、GIS)进行了升级,增加了对 JSON 多值索引和 JSON 部分更新的支持,进一步促进了异构数据的迁移和融合。为了提升逐步增长的入库需求,OceanBase V4.3.1 提供了增量旁路导入能力,并优化了在具有多局部索引的场景下的 DML 性能,提高了基础统计信息收集效率,并在行采样、小规格 TP 场景取得了显著的性能提升。在资源利用方面,新版本通过引入 CLOG 日志缓存、SQL 临时结果和系统日志压缩特性,实现了更为高效的资源使用。通过补全 MySQL 权限体系和支持操作系统配置检查,加固系统安全。用户体验方面,OceanBase 继续提供高质量服务,提升了资源规格估算的能力,增强了备份的透明度,并加入了对 IPV6 格式的支持,以便于数据库管理和运维工作。在确保与生态系统兼容性的同时,MySQL 与 Oracle 的兼容性也得到了持续增强,支持了 Lateral Derived Tables、MySQL 锁函数、Oracle 视图注释与远程 UDF 调用等特性,这进一步确保了 OceanBase 能够无缝融入现有生态。") +# print(pred) diff --git a/load_docs_for_graph_rag.py b/load_docs_for_graph_rag.py new file mode 100644 index 0000000..4822777 --- /dev/null +++ b/load_docs_for_graph_rag.py @@ -0,0 +1,314 @@ +import os +import uuid +from typing import List, Optional + +from pydantic import BaseModel +from pyobvector import ObVecClient +from langchain.text_splitter import MarkdownHeaderTextSplitter +from neo4j import GraphDatabase + +from concurrent.futures import ThreadPoolExecutor +import asyncio + +from llm_extractor import ExtractKG, KnowledgeGraph, Entity, Relationship +import dspy + +headers_to_split_on = [ + ("#", "Header1"), + ("##", "Header2"), + ("###", "Header3"), + ("####", "Header4"), + ("#####", "Header5"), + ("######", "Header6"), +] + +splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, +) + +neo_uri = "neo4j://localhost:7687" +user = "neo4j" +password = os.environ.get("NEO4J_PASSWORD", "") +graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password)) + +CHUNK_INCLUDE_CHUNK = "chunk_include_chunk" +CHUNK_NEXT_CHUNK = "chunk_next_chunk" +DOC_INCLUDE_CHUNK = "doc_include_chunk" +DOC_INCLUDE_DOC = "doc_include_doc" +RELATIONSHIP = "relationship" +CHUNK_INCLUDE_ENTITY = "chunk_include_entity" + +tongyi_lm = dspy.LM( + model="openai/qwen-plus", + api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", + api_key=os.environ.get("DASHSCOPE_API_KEY") +) +dspy.settings.configure(lm=tongyi_lm) + +extract_executor = ThreadPoolExecutor(max_workers=16) + +def reset_graphdb(): + graph_db.execute_query("MATCH (n1:Chunk)-[r1]->(m1:Chunk), (d2:Doc)-[r2]->(n2:Chunk), (d3:Doc)-[r3]->(d4:Doc) DELETE n1,r1,m1,d2,r2,n2,d3,r3,d4") + + +class ChunkWithRelation(BaseModel): + chunk_id: str + content: str + chunk_name: str + lv: int + parent_chunk: Optional["ChunkWithRelation"] + next_chunk: Optional["ChunkWithRelation"] + +class Doc(BaseModel): + doc_id: str + doc_name: str + keywords: List[str] + + + + +def parse_headern_to_lv(headern: str): + return int(headern[len("Header"):]) + +def extract_chunk_kg( + chunk: ChunkWithRelation +): + kg_extractor = dspy.Predict(ExtractKG) + pred = kg_extractor(text=chunk.chunk_name + ": " + chunk.content) + kg: KnowledgeGraph = pred.knowledge_graph + print("======================= extract_chunk_kg: ", kg.model_dump()) + graphdb_upsert_graph(kg) + +async def aget_doc_tree( + file_path: str, +) -> List[ChunkWithRelation]: + chunks_with_rel: List[ChunkWithRelation] = [] + level_stk: List[ChunkWithRelation] = [] + with open(file_path, "r", encoding="utf-8") as f: + file_content = f.read() + chunks = splitter.split_text(file_content) + for chunk in chunks: + metadata_keys = list(chunk.metadata.keys()) + chunk_name = '-'.join(list(chunk.metadata.values())) + # TODO: if chunk_name does not exists + chunk_lv = ( + 0 if len(metadata_keys) == 0 + else parse_headern_to_lv(metadata_keys[-1]) + ) + + # pop stk if necessary + while len(level_stk) > 0 and chunk_lv <= level_stk[-1].lv: + level_stk.pop() + + chunk_with_rel = ChunkWithRelation( + chunk_id=str(uuid.uuid4()), + content=chunk.page_content, + chunk_name=chunk_name, + lv=chunk_lv, + parent_chunk=( + None if len(level_stk) == 0 + else level_stk[-1] + ), + next_chunk=None, + ) + if len(chunks_with_rel) > 0: + chunks_with_rel[-1].next_chunk = chunk_with_rel + chunks_with_rel.append(chunk_with_rel) + level_stk.append(chunk_with_rel) + return chunks_with_rel + +def graphdb_upsert_chunks( + chunks: List[ChunkWithRelation], +): + def _create_chunk(tx, chunk_batch): + for chunk in chunk_batch: + query = ( + "CREATE (chunk: Chunk {id: $id, name: $name, content: $content})" + ) + tx.run(query, id=chunk.chunk_id, name=chunk.chunk_name, content=chunk.content) + + with graph_db.session() as session: + session.execute_write(_create_chunk, chunks) + +def graphdb_upsert_chunk_rels( + chunks: List[ChunkWithRelation], +): + def _create_chunk_rels(tx, chunk_batch): + for chunk in chunk_batch: + parent_chunk = chunk.parent_chunk + next_chunk = chunk.next_chunk + if parent_chunk: + query = ( + f"MATCH (s: Chunk {{id: $parent_id}}), (t: Chunk {{id: $id}}) " \ + f"CREATE (s)-[r: {CHUNK_INCLUDE_CHUNK}]->(t)" + ) + tx.run(query, parent_id=parent_chunk.chunk_id, id=chunk.chunk_id) + if next_chunk: + query = ( + f"MATCH (s: Chunk {{id: $id}}), (t: Chunk {{id: $next_id}}) " \ + f"CREATE (s)-[r: {CHUNK_NEXT_CHUNK}]->(t)" + ) + tx.run(query, id=chunk.chunk_id, next_id=next_chunk.chunk_id) + with graph_db.session() as session: + session.execute_write(_create_chunk_rels, chunks) + +def graphdb_upsert_doc( + doc: Doc +): + def _create_doc(tx, doc): + query = ( + "CREATE (d: Doc {id: $id, name: $name, keywords: $keywords})" + ) + tx.run(query, id=doc.doc_id, name=doc.doc_name, keywords=doc.keywords) + + with graph_db.session() as session: + session.execute_write(_create_doc, doc) + +def graphdb_upsert_doc_include_chunk( + doc: Doc, + hd_chunks: List[ChunkWithRelation] +): + def _create_doc_include_chunks(tx, doc, chunks): + for chunk in chunks: + query = ( + f"MATCH (s: Doc {{id: $doc_id}}), (t: Chunk {{id: $chunk_id}}) " \ + f"CREATE (s)-[r: {DOC_INCLUDE_CHUNK}]->(t)" + ) + tx.run(query, doc_id=doc.doc_id, chunk_id=chunk.chunk_id) + with graph_db.session() as session: + session.execute_write(_create_doc_include_chunks, doc, hd_chunks) + +def graphdb_upsert_doc_include_doc( + parent_doc: Doc, + doc: Doc +): + def _create_doc_include_doc(tx, par_doc, doc): + query = ( + f"MATCH (s: Doc {{id: $par_id}}), (t: Doc {{id: $doc_id}}) " \ + f"CREATE (s)-[r: {DOC_INCLUDE_DOC}]->(t)" + ) + tx.run(query, par_id=par_doc.doc_id, doc_id=doc.doc_id) + with graph_db.session() as session: + session.execute_write(_create_doc_include_doc, parent_doc, doc) + +def graphdb_upsert_entities( + entities: List[Entity] +): + def _create_entities(tx, ents: List[Entity]): + for ent in ents: + query = ( + f"CREATE (e: Entity {{name: $name, description: $desc}})" + ) + tx.run(query, name=ent.name, desc=ent.description) + with graph_db.session() as session: + session.execute_write(_create_entities, entities) + +def graphdb_upsert_relations( + relations: List[Relationship] +): + def _create_rels(tx, rels: List[Relationship]): + for rel in rels: + query = ( + f"MATCH (s: {{name: $sname}}), (t: {{name: $tname}}) " \ + f"CREATE (s)-[r: {RELATIONSHIP} {{description: $rdesc}}]->(t)" + ) + tx.run(query, sname=rel.source_entity, tname=rel.target_entity, rdesc=rel.relationship_desc) + with graph_db.session() as session: + session.execute_write(_create_rels, relations) + +def graphdb_upsert_chunk_include_entities( + chunk_id: str, + entities: List[Entity] +): + def _create_chunk_include_entities(tx, cid, ents): + for ent in ents: + query = ( + f"MATCH (c: Chunk {{id: $id}}), (e: Entity {{name: $name}}) " \ + f"CREATE (c)-[r: {CHUNK_INCLUDE_ENTITY}]->(e)" + ) + tx.run(query, id=cid, name=ent.name) + with graph_db.session() as session: + session.execute_write(_create_chunk_include_entities, chunk_id, entities) + +def graphdb_upsert_graph( + graph: KnowledgeGraph, + chunk: ChunkWithRelation, +): + graphdb_upsert_entities(graph.entities) + graphdb_upsert_relations(graph.relationships) + graphdb_upsert_chunk_include_entities(chunk.chunk_id, graph.entities) + + +async def aload_doc_graph( + doc_root_path: str, + relative_path: str, +): + print("=========== start_load_doc_graph: ", doc_root_path, relative_path) + file_path = os.path.join(doc_root_path, relative_path) + + # Get Document tree + chunks_with_rel = await aget_doc_tree(file_path=file_path) + + # extract_tasks = [] + # for chunk in chunks_with_rel: + # extract_tasks.append( + # await loop.run_in_executor(extract_executor, extract_chunk_kg, chunk) + # ) + # await asyncio.gather(*extract_tasks) + + graphdb_upsert_chunks(chunks_with_rel) + graphdb_upsert_chunk_rels(chunks_with_rel) + + return chunks_with_rel[0] if len(chunks_with_rel) > 0 else None + +async def aload_doc( + doc_root_path: str, + doc_repo_name: str, + parent_doc: Optional[Doc] = None, +): + files = [] + dirs = [] + with os.scandir(doc_root_path) as entries: + for entry in entries: + if entry.is_file() and entry.name.endswith(".md"): + files.append(entry.name) + elif entry.is_dir(): + new_doc_root_path = os.path.join(doc_root_path, entry.name) + dirs.append(entry.name) + + load_doc_graph_tasks = [] + for file in files: + load_doc_graph_tasks.append( + aload_doc_graph(doc_root_path, file) + ) + if len(load_doc_graph_tasks) > 0: + hd_chunks = await asyncio.gather(*load_doc_graph_tasks) + + doc = Doc( + doc_id=str(uuid.uuid4()), + doc_name=doc_repo_name[doc_repo_name.find('.') + 1:], + keywords=[], + ) + graphdb_upsert_doc(doc) + if len(hd_chunks) > 0: + graphdb_upsert_doc_include_chunk(doc, hd_chunks) + if parent_doc: + graphdb_upsert_doc_include_doc(parent_doc, doc) + + print(f"Doc: {{doc_name: {doc.doc_name}, doc_path: {doc_root_path}}}") + + load_doc_tasks = [] + for dir in dirs: + new_doc_root_path = os.path.join(doc_root_path, dir) + load_doc_tasks.append( + aload_doc(new_doc_root_path, dir, doc) + ) + if len(load_doc_tasks) > 0: + await asyncio.gather(*load_doc_tasks) + + +reset_graphdb() +loop = asyncio.get_event_loop() +loop.run_until_complete( + aload_doc(doc_root_path="./doc_test", doc_repo_name="OceanBase") +) -- Gitee From a5b6aa1652d4f8faabc41f6b54d70ab2b49148e5 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Thu, 28 Nov 2024 10:34:05 +0800 Subject: [PATCH 2/3] fix dspy Signed-off-by: shanhaikang.shk --- llm_extractor.py | 36 +++++++++++++++++- load_docs_for_graph_rag.py | 75 ++++++++++++++++++++++++++------------ 2 files changed, 86 insertions(+), 25 deletions(-) diff --git a/llm_extractor.py b/llm_extractor.py index 03cf93f..135cf8a 100644 --- a/llm_extractor.py +++ b/llm_extractor.py @@ -15,6 +15,9 @@ class Relationship(BaseModel): target_entity: str = Field( description="关系的目标实体名字,要求必须是出现在实体列表当中的实体名" ) + relation_name: str = Field( + description="关系名,一般是一个谓词" + ) class KnowledgeGraph(BaseModel): entities: List[Entity] = Field( @@ -28,8 +31,37 @@ class ExtractKG(dspy.Signature): text: str = dspy.InputField( desc="基于这段文本抽取实体和关系来形成一个知识图谱" ) - knowledge_graph: KnowledgeGraph = dspy.OutputField( - desc="基于文本抽取得到的知识图谱" + entities: List[str] = dspy.OutputField( + desc="知识图谱当中的一组实体,每个实体的格式为'实体名',要求实体名不重复,要求数量不超过10个" + ) + relationships: List[str] = dspy.OutputField( + desc="知识图谱当中的一组关系,每个关系的格式为'来源实体名#关系名#目标实体名',来源实体名、关系名、目标实体名中存在'#'时使用'_'替换,要求来源实体名和目标实体名必须在实体列表中包含" + ) + # knowledge_graph: KnowledgeGraph = dspy.OutputField( + # desc="基于文本抽取得到的知识图谱" + # ) + +def parse_extract_output_to_kg(pred) -> KnowledgeGraph: + entities = [] + for ent in pred.entities: + entities.append(Entity( + name=ent + )) + + rels = [] + for rel in pred.relationships: + rel_eles = rel.split('#') + if len(rel_eles) != 3: + continue + rels.append(Relationship( + source_entity=rel_eles[0], + target_entity=rel_eles[2], + relation_name=rel_eles[1], + )) + + return KnowledgeGraph( + entities=entities, + relationships=rels ) # tongyi_lm = dspy.LM( diff --git a/load_docs_for_graph_rag.py b/load_docs_for_graph_rag.py index 4822777..be19cbb 100644 --- a/load_docs_for_graph_rag.py +++ b/load_docs_for_graph_rag.py @@ -3,14 +3,16 @@ import uuid from typing import List, Optional from pydantic import BaseModel -from pyobvector import ObVecClient +from pyobvector import ObVecClient, VECTOR +from sqlalchemy import Column, Integer, String +from sqlalchemy.dialects.mysql import TEXT from langchain.text_splitter import MarkdownHeaderTextSplitter from neo4j import GraphDatabase from concurrent.futures import ThreadPoolExecutor import asyncio -from llm_extractor import ExtractKG, KnowledgeGraph, Entity, Relationship +from llm_extractor import ExtractKG, KnowledgeGraph, Entity, Relationship, parse_extract_output_to_kg import dspy headers_to_split_on = [ @@ -31,6 +33,28 @@ user = "neo4j" password = os.environ.get("NEO4J_PASSWORD", "") graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password)) +ob = ObVecClient() +cols = [ + Column("id", Integer, primary_key=True, autoincrement=False), + Column("content_embedding", VECTOR(1024)), + Column("content", TEXT), + Column("chunk_id", String(128)), +] +CONTENT_EMBED_TABLE = "content_embed_table" +if not ob.check_table_exists(CONTENT_EMBED_TABLE): + ob.create_table( + CONTENT_EMBED_TABLE, columns=cols + ) + # create vector index + ob.create_index( + CONTENT_EMBED_TABLE, + is_vec_index=True, + index_name="vidx", + column_names=["content_embedding"], + vidx_params="distance=l2, type=hnsw, lib=vsag", + ) + +INCLUDE = "include" CHUNK_INCLUDE_CHUNK = "chunk_include_chunk" CHUNK_NEXT_CHUNK = "chunk_next_chunk" DOC_INCLUDE_CHUNK = "doc_include_chunk" @@ -48,7 +72,7 @@ dspy.settings.configure(lm=tongyi_lm) extract_executor = ThreadPoolExecutor(max_workers=16) def reset_graphdb(): - graph_db.execute_query("MATCH (n1:Chunk)-[r1]->(m1:Chunk), (d2:Doc)-[r2]->(n2:Chunk), (d3:Doc)-[r3]->(d4:Doc) DELETE n1,r1,m1,d2,r2,n2,d3,r3,d4") + graph_db.execute_query("MATCH (n) DETACH DELETE n") class ChunkWithRelation(BaseModel): @@ -73,11 +97,14 @@ def parse_headern_to_lv(headern: str): def extract_chunk_kg( chunk: ChunkWithRelation ): - kg_extractor = dspy.Predict(ExtractKG) - pred = kg_extractor(text=chunk.chunk_name + ": " + chunk.content) - kg: KnowledgeGraph = pred.knowledge_graph + try: + kg_extractor = dspy.Predict(ExtractKG) + pred = kg_extractor(text=chunk.chunk_name + ": " + chunk.content) + except Exception as e: + return + kg: KnowledgeGraph = parse_extract_output_to_kg(pred) print("======================= extract_chunk_kg: ", kg.model_dump()) - graphdb_upsert_graph(kg) + graphdb_upsert_graph(kg, chunk) async def aget_doc_tree( file_path: str, @@ -85,11 +112,13 @@ async def aget_doc_tree( chunks_with_rel: List[ChunkWithRelation] = [] level_stk: List[ChunkWithRelation] = [] with open(file_path, "r", encoding="utf-8") as f: + file_name = os.path.basename(file_path) file_content = f.read() chunks = splitter.split_text(file_content) for chunk in chunks: metadata_keys = list(chunk.metadata.keys()) chunk_name = '-'.join(list(chunk.metadata.values())) + chunk_name = '-'.join([file_name, chunk_name]) # TODO: if chunk_name does not exists chunk_lv = ( 0 if len(metadata_keys) == 0 @@ -140,13 +169,13 @@ def graphdb_upsert_chunk_rels( if parent_chunk: query = ( f"MATCH (s: Chunk {{id: $parent_id}}), (t: Chunk {{id: $id}}) " \ - f"CREATE (s)-[r: {CHUNK_INCLUDE_CHUNK}]->(t)" + f"CREATE (s)-[r: {INCLUDE} {{description: '{CHUNK_INCLUDE_CHUNK}'}}]->(t)" ) tx.run(query, parent_id=parent_chunk.chunk_id, id=chunk.chunk_id) if next_chunk: query = ( f"MATCH (s: Chunk {{id: $id}}), (t: Chunk {{id: $next_id}}) " \ - f"CREATE (s)-[r: {CHUNK_NEXT_CHUNK}]->(t)" + f"CREATE (s)-[r: {INCLUDE} {{description: '{CHUNK_NEXT_CHUNK}'}}]->(t)" ) tx.run(query, id=chunk.chunk_id, next_id=next_chunk.chunk_id) with graph_db.session() as session: @@ -172,7 +201,7 @@ def graphdb_upsert_doc_include_chunk( for chunk in chunks: query = ( f"MATCH (s: Doc {{id: $doc_id}}), (t: Chunk {{id: $chunk_id}}) " \ - f"CREATE (s)-[r: {DOC_INCLUDE_CHUNK}]->(t)" + f"CREATE (s)-[r: {INCLUDE} {{description: '{DOC_INCLUDE_CHUNK}'}}]->(t)" ) tx.run(query, doc_id=doc.doc_id, chunk_id=chunk.chunk_id) with graph_db.session() as session: @@ -185,7 +214,7 @@ def graphdb_upsert_doc_include_doc( def _create_doc_include_doc(tx, par_doc, doc): query = ( f"MATCH (s: Doc {{id: $par_id}}), (t: Doc {{id: $doc_id}}) " \ - f"CREATE (s)-[r: {DOC_INCLUDE_DOC}]->(t)" + f"CREATE (s)-[r: {INCLUDE} {{description: '{DOC_INCLUDE_DOC}'}}]->(t)" ) tx.run(query, par_id=par_doc.doc_id, doc_id=doc.doc_id) with graph_db.session() as session: @@ -197,9 +226,9 @@ def graphdb_upsert_entities( def _create_entities(tx, ents: List[Entity]): for ent in ents: query = ( - f"CREATE (e: Entity {{name: $name, description: $desc}})" + f"MERGE (e: Entity {{name: $name}})" ) - tx.run(query, name=ent.name, desc=ent.description) + tx.run(query, name=ent.name) with graph_db.session() as session: session.execute_write(_create_entities, entities) @@ -209,10 +238,10 @@ def graphdb_upsert_relations( def _create_rels(tx, rels: List[Relationship]): for rel in rels: query = ( - f"MATCH (s: {{name: $sname}}), (t: {{name: $tname}}) " \ + f"MATCH (s: Entity {{name: $sname}}), (t: Entity {{name: $tname}}) " \ f"CREATE (s)-[r: {RELATIONSHIP} {{description: $rdesc}}]->(t)" ) - tx.run(query, sname=rel.source_entity, tname=rel.target_entity, rdesc=rel.relationship_desc) + tx.run(query, sname=rel.source_entity, tname=rel.target_entity, rdesc=rel.relation_name) with graph_db.session() as session: session.execute_write(_create_rels, relations) @@ -224,7 +253,7 @@ def graphdb_upsert_chunk_include_entities( for ent in ents: query = ( f"MATCH (c: Chunk {{id: $id}}), (e: Entity {{name: $name}}) " \ - f"CREATE (c)-[r: {CHUNK_INCLUDE_ENTITY}]->(e)" + f"CREATE (c)-[r: {INCLUDE} {{description: '{CHUNK_INCLUDE_ENTITY}'}}]->(e)" ) tx.run(query, id=cid, name=ent.name) with graph_db.session() as session: @@ -248,17 +277,17 @@ async def aload_doc_graph( # Get Document tree chunks_with_rel = await aget_doc_tree(file_path=file_path) - - # extract_tasks = [] - # for chunk in chunks_with_rel: - # extract_tasks.append( - # await loop.run_in_executor(extract_executor, extract_chunk_kg, chunk) - # ) - # await asyncio.gather(*extract_tasks) graphdb_upsert_chunks(chunks_with_rel) graphdb_upsert_chunk_rels(chunks_with_rel) + extract_tasks = [] + for chunk in chunks_with_rel: + extract_tasks.append( + loop.run_in_executor(extract_executor, extract_chunk_kg, chunk) + ) + await asyncio.gather(*extract_tasks) + return chunks_with_rel[0] if len(chunks_with_rel) > 0 else None async def aload_doc( -- Gitee From 250a21182f9a73df8410f822cdca8a132de21481 Mon Sep 17 00:00:00 2001 From: "shanhaikang.shk" Date: Thu, 28 Nov 2024 17:13:52 +0800 Subject: [PATCH 3/3] ob graphrag demo finish Signed-off-by: shanhaikang.shk --- llm_extractor.py | 19 +++-- load_docs_for_graph_rag.py | 58 ++++++++++++++-- query_for_graph_rag.py | 139 +++++++++++++++++++++++++++++++++++++ 3 files changed, 201 insertions(+), 15 deletions(-) create mode 100644 query_for_graph_rag.py diff --git a/llm_extractor.py b/llm_extractor.py index 135cf8a..2be9b72 100644 --- a/llm_extractor.py +++ b/llm_extractor.py @@ -41,6 +41,14 @@ class ExtractKG(dspy.Signature): # desc="基于文本抽取得到的知识图谱" # ) +class ExtractKeywords(dspy.Signature): + text: str = dspy.InputField( + desc="基于这段文本抽取关键词列表以便于在知识图谱中搜索" + ) + keywords: List[str] = dspy.OutputField( + desc="一组关键词,要求每个关键词清晰明确,尽可能多的包含同义词" + ) + def parse_extract_output_to_kg(pred) -> KnowledgeGraph: entities = [] for ent in pred.entities: @@ -63,14 +71,3 @@ def parse_extract_output_to_kg(pred) -> KnowledgeGraph: entities=entities, relationships=rels ) - -# tongyi_lm = dspy.LM( -# model="openai/qwen-plus", -# api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", -# api_key=os.environ.get("DASHSCOPE_API_KEY") -# ) -# dspy.settings.configure(lm=tongyi_lm) - -# kg_extractor = dspy.Predict(ExtractKG) -# pred = kg_extractor(text="OceanBase V4.3.1 版本在 V4.3.0 的基础上带来了新功能与性能优化。本次升级引入了全文索引特性,提高了文档检索效率。新增的功能还包括实时物化视图、物化视图改写及主键物化视图等,进而满足了多场景的数据分析需求。此外,引入的分区管理新机制,如分区交换、外表分区及 MySQL 模式下分区键数据类型的扩展,显著提升了处理大规模数据集的能力。本次版本更新同样针对多模特性(包括 JSON、XML、GIS)进行了升级,增加了对 JSON 多值索引和 JSON 部分更新的支持,进一步促进了异构数据的迁移和融合。为了提升逐步增长的入库需求,OceanBase V4.3.1 提供了增量旁路导入能力,并优化了在具有多局部索引的场景下的 DML 性能,提高了基础统计信息收集效率,并在行采样、小规格 TP 场景取得了显著的性能提升。在资源利用方面,新版本通过引入 CLOG 日志缓存、SQL 临时结果和系统日志压缩特性,实现了更为高效的资源使用。通过补全 MySQL 权限体系和支持操作系统配置检查,加固系统安全。用户体验方面,OceanBase 继续提供高质量服务,提升了资源规格估算的能力,增强了备份的透明度,并加入了对 IPV6 格式的支持,以便于数据库管理和运维工作。在确保与生态系统兼容性的同时,MySQL 与 Oracle 的兼容性也得到了持续增强,支持了 Lateral Derived Tables、MySQL 锁函数、Oracle 视图注释与远程 UDF 调用等特性,这进一步确保了 OceanBase 能够无缝融入现有生态。") -# print(pred) diff --git a/load_docs_for_graph_rag.py b/load_docs_for_graph_rag.py index be19cbb..83b7249 100644 --- a/load_docs_for_graph_rag.py +++ b/load_docs_for_graph_rag.py @@ -1,4 +1,5 @@ import os +import requests import uuid from typing import List, Optional @@ -35,13 +36,13 @@ graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password)) ob = ObVecClient() cols = [ - Column("id", Integer, primary_key=True, autoincrement=False), + Column("chunk_id", String(128), primary_key=True, autoincrement=False), Column("content_embedding", VECTOR(1024)), Column("content", TEXT), - Column("chunk_id", String(128)), ] CONTENT_EMBED_TABLE = "content_embed_table" if not ob.check_table_exists(CONTENT_EMBED_TABLE): + print(f"################### create table {CONTENT_EMBED_TABLE}") ob.create_table( CONTENT_EMBED_TABLE, columns=cols ) @@ -53,6 +54,7 @@ if not ob.check_table_exists(CONTENT_EMBED_TABLE): column_names=["content_embedding"], vidx_params="distance=l2, type=hnsw, lib=vsag", ) +OB_DEFAULT_BATCH_SIZE = 10 INCLUDE = "include" CHUNK_INCLUDE_CHUNK = "chunk_include_chunk" @@ -71,6 +73,9 @@ dspy.settings.configure(lm=tongyi_lm) extract_executor = ThreadPoolExecutor(max_workers=16) +def reset_ob(): + ob.perform_raw_text_sql(f"DROP TABLE {CONTENT_EMBED_TABLE}") + def reset_graphdb(): graph_db.execute_query("MATCH (n) DETACH DELETE n") @@ -94,6 +99,49 @@ class Doc(BaseModel): def parse_headern_to_lv(headern: str): return int(headern[len("Header"):]) +def embedding(queries: List[str]): + res = requests.post( + os.environ.get("REMOTE_BGE_URL", ""), + json={"model": "bge-m3", "input": queries}, + headers={ + "X-Token": os.environ.get("REMOTE_BGE_TOKEN", "") + }, + ) + try: + data = res.json() + except Exception as e: + print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX") + raise e + return data["embeddings"] + +def embed_chunks( + chunks: List[ChunkWithRelation], +): + chunk_contents = [chunk.content for chunk in chunks] + vecs = embedding(chunk_contents) + datas = [ + { + "chunk_id": chunk.chunk_id, + "content_embedding": vec, + "content": chunk.content, + } + for (chunk, vec) in zip(chunks, vecs) + ] + ob.insert(CONTENT_EMBED_TABLE, datas) + + +async def embed_chunks_with_batch_size( + executor: ThreadPoolExecutor, + chunks: List[ChunkWithRelation], + batch_size: int = OB_DEFAULT_BATCH_SIZE, +): + tasks = [] + for start_idx in range(0, len(chunks), batch_size): + tasks.append( + loop.run_in_executor(executor, embed_chunks, chunks[start_idx : start_idx+batch_size]) + ) + await asyncio.gather(*tasks) + def extract_chunk_kg( chunk: ChunkWithRelation ): @@ -239,7 +287,7 @@ def graphdb_upsert_relations( for rel in rels: query = ( f"MATCH (s: Entity {{name: $sname}}), (t: Entity {{name: $tname}}) " \ - f"CREATE (s)-[r: {RELATIONSHIP} {{description: $rdesc}}]->(t)" + f"MERGE (s)-[r: {RELATIONSHIP} {{description: $rdesc}}]->(t)" ) tx.run(query, sname=rel.source_entity, tname=rel.target_entity, rdesc=rel.relation_name) with graph_db.session() as session: @@ -288,6 +336,8 @@ async def aload_doc_graph( ) await asyncio.gather(*extract_tasks) + await embed_chunks_with_batch_size(extract_executor, chunks_with_rel) + return chunks_with_rel[0] if len(chunks_with_rel) > 0 else None async def aload_doc( @@ -334,7 +384,7 @@ async def aload_doc( ) if len(load_doc_tasks) > 0: await asyncio.gather(*load_doc_tasks) - + reset_graphdb() loop = asyncio.get_event_loop() diff --git a/query_for_graph_rag.py b/query_for_graph_rag.py new file mode 100644 index 0000000..fc4a030 --- /dev/null +++ b/query_for_graph_rag.py @@ -0,0 +1,139 @@ +import os +import dspy +import requests +from typing import List + +from llm_extractor import ExtractKeywords +import dspy +from neo4j import GraphDatabase +from pyobvector import ObVecClient, VECTOR +from sqlalchemy import func + +tongyi_lm = dspy.LM( + model="openai/qwen-plus", + api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", + api_key=os.environ.get("DASHSCOPE_API_KEY") +) +dspy.settings.configure(lm=tongyi_lm) + +neo_uri = "neo4j://localhost:7687" +user = "neo4j" +password = os.environ.get("NEO4J_PASSWORD", "") +graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password)) + +ob = ObVecClient() +CONTENT_EMBED_TABLE = "content_embed_table" + +VECTOR_RECALL_TOPK = 3 + +def extract_keywords(query: str): + try: + keywords_extractor = dspy.Predict(ExtractKeywords) + pred = keywords_extractor(text=query) + return pred.keywords + except Exception as e: + print("XXXXXXXXXX failed to extract keywords") + raise e + +# print(extract_keywords("OceanBase是什么")) + +def query_graphdb(query: str): + keywords = extract_keywords(query) + print(keywords) + with graph_db.session() as session: + # res = session.run( + # f"MATCH (e: Entity) WHERE e.name IN {keywords} RETURN e" + # ) + # ents = [r for r in res] + res = session.run( + f"MATCH (c: Chunk)-[r]->(e: Entity) WHERE e.name IN {keywords} RETURN c" + ) + leaf_chunks = [r for r in res] + return leaf_chunks + +def query_graphdb_entities_and_rels_with_chunk_ids( + chunk_ids: List[str] +): + with graph_db.session() as session: + chunks_res = session.run( + f"MATCH (c: Chunk)-[]->(e1:Entity) WHERE c.id IN {chunk_ids} " \ + f"MATCH (c: Chunk)-[]->(e2:Entity) WHERE c.id IN {chunk_ids} " \ + f"MATCH p=(e1)-[r]->(e2) RETURN p" + ) + relations = [] + for r in chunks_res: + start_ent = r['p'].start_node['name'] + end_ent = r['p'].end_node['name'] + for rel in r['p'].relationships: + relations.append(start_ent + "#" + rel['description'] + "#" + end_ent) + return relations + +def embedding(queries: List[str]): + res = requests.post( + os.environ.get("REMOTE_BGE_URL", ""), + json={"model": "bge-m3", "input": queries}, + headers={ + "X-Token": os.environ.get("REMOTE_BGE_TOKEN", "") + }, + ) + try: + data = res.json() + except Exception as e: + print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX") + raise e + return data["embeddings"] + +def query_vecdb(query: str): + vec = embedding([query])[0] + res = ob.ann_search( + table_name=CONTENT_EMBED_TABLE, + vec_data=vec, + vec_column_name="content_embedding", + distance_func=func.l2_distance, + with_dist=False, + topk=VECTOR_RECALL_TOPK, + output_column_names=["chunk_id", "content"], + ) + return [ + { + "chunk_id": r[0], + "content": r[1] + } + for r in res + ] + +PROMPT = """ + 你是一个知识库问答助手,非常擅长利用文档上下文以及文档中实体的关系为用户提供详实、正确的问答服务 + + 以下是相关的文档上下文: + {context} + + 以下是文档上下文中的实体关系(每一行表示一组实体关系,格式为'起始实体#关系#目标实体'): + {relations} + + 以下是用户的问题: + {query} + + 现在请回答用户的问题: +""" + +def response_query( + query: str, +): + vres = query_vecdb(query) + chunk_ids = [v["chunk_id"] for v in vres] + rels = query_graphdb_entities_and_rels_with_chunk_ids(chunk_ids) + # print("\n".join(list(set(rels)))) + prompt = PROMPT.format( + context="\n".join([v["content"] for v in vres]), + relations="\n".join(list(set(rels))), + query=query, + ) + print(prompt) + return tongyi_lm(prompt) + +while True: + query = input("> ") + res = response_query(query) + print(f"====================================\n{res}\n==============================") + -- Gitee