| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- # -*- coding: utf-8 -*-
- """Parent Document RAG - retrieve child chunks, return parent documents."""
- from .base import BaseRAG, Document
- from loguru import logger
class ParentDocumentRAG(BaseRAG):
    """Retrieve on small child chunks, answer with their larger parent chunks.

    Each document is split on whitespace into parent chunks of
    ``parent_chunk_size`` words, and every parent chunk is split again into
    child chunks of ``child_chunk_size`` words. Only the child chunks are
    embedded and added to the vector store; at query time each child hit is
    mapped back to the parent chunk it was cut from.
    """

    # Number of leading characters of a child chunk used to locate (or
    # confirm) its parent chunk by substring match.
    _MATCH_PREFIX_LEN = 50

    def __init__(self, parent_chunk_size=1500, *, child_chunk_size=512, **kwargs):
        """Configure chunk sizes and forward everything else to ``BaseRAG``.

        Args:
            parent_chunk_size: Words per parent chunk.
            child_chunk_size: Words per child chunk (keyword-only).
            **kwargs: Passed unchanged to ``BaseRAG.__init__``.

        Raises:
            ValueError: If either chunk size is not positive.
        """
        if parent_chunk_size <= 0 or child_chunk_size <= 0:
            # A zero step would only fail later inside range(), and a negative
            # one would silently index nothing; fail early and clearly.
            raise ValueError(
                "parent_chunk_size and child_chunk_size must be positive, "
                "got {} and {}".format(parent_chunk_size, child_chunk_size)
            )
        super().__init__(**kwargs)
        self.parent_chunk_size = parent_chunk_size
        self.child_chunk_size = child_chunk_size
        self.parent_docs = []
        # child id (position in the embedded batch) -> index into parent_docs
        self.child_to_parent = {}

    def index_documents(self, documents):
        """Split ``documents`` into parent/child chunks and index the children.

        Replaces ``self.parent_docs`` and ``self.child_to_parent``. Child
        chunks are embedded with ``self.embedding_model`` and added to
        ``self.vector_store``; nothing is added when there are no children.

        Args:
            documents: Iterable of objects exposing ``page_content`` (str) and
                ``metadata`` (mapping). It is consumed in a single pass, so a
                generator works as well as a list.
        """
        parent_docs = []
        child_to_parent = {}
        child_texts = []

        for doc in documents:
            words = doc.page_content.split()
            for start in range(0, len(words), self.parent_chunk_size):
                parent_words = words[start:start + self.parent_chunk_size]
                parent_index = len(parent_docs)
                parent_docs.append(
                    Document(
                        page_content=" ".join(parent_words),
                        # NOTE: ``chunk_index`` holds the word offset of the
                        # chunk within its document, not a running counter.
                        metadata=dict(doc.metadata, chunk_index=start, is_parent=True),
                    )
                )
                # Record the real parent of every child instead of assuming a
                # fixed number of children per parent.
                for offset in range(0, len(parent_words), self.child_chunk_size):
                    child_to_parent[len(child_texts)] = parent_index
                    child_texts.append(
                        " ".join(parent_words[offset:offset + self.child_chunk_size])
                    )

        self.parent_docs = parent_docs
        self.child_to_parent = child_to_parent

        if child_texts:
            embeddings = self.embedding_model.embed_documents(child_texts)
            child_docs = [
                Document(
                    page_content=text,
                    # ``child_id`` lets retrieve() resolve the parent through
                    # ``child_to_parent`` without scanning every parent.
                    metadata={"is_parent": False, "child_id": child_id},
                )
                for child_id, text in enumerate(child_texts)
            ]
            self.vector_store.add_documents(child_docs, embeddings)

        logger.info(
            "ParentDocumentRAG: {} parents, {} children",
            len(self.parent_docs),
            len(child_texts),
        )

    def _find_parent_index(self, child_doc):
        """Return the index in ``self.parent_docs`` for ``child_doc``, or None.

        The ``child_id`` stored in the child's metadata is tried first. The
        candidate is accepted only if the child's leading text really occurs in
        that parent, so hits left over from an earlier indexing run cannot be
        attributed to the wrong parent. Otherwise the first parent containing
        the child's leading text is used.
        """
        probe = child_doc.page_content[:self._MATCH_PREFIX_LEN]

        metadata = getattr(child_doc, "metadata", None)
        child_id = metadata.get("child_id") if isinstance(metadata, dict) else None
        candidate = self.child_to_parent.get(child_id)
        if (
            isinstance(candidate, int)
            and 0 <= candidate < len(self.parent_docs)
            and probe in self.parent_docs[candidate].page_content
        ):
            return candidate

        for index, parent_doc in enumerate(self.parent_docs):
            if probe in parent_doc.page_content:
                return index
        return None

    def retrieve(self, query, k=10):
        """Return up to ``k`` parent chunks for ``query``.

        ``k * 2`` child hits are requested from the vector store so that
        several children collapsing onto one parent still leave enough distinct
        parents. Each parent keeps the highest score among its children and
        parents are ordered by that score, descending, before being handed to
        ``self._deduplicate_results``.

        Args:
            query: Query text.
            k: Maximum number of results passed to ``_deduplicate_results``.

        Returns:
            Whatever ``self._deduplicate_results`` returns for the ranked
            ``(parent_doc, score)`` pairs.
        """
        query_embedding = self.embedding_model.embed_query(query)
        child_results = self.vector_store.similarity_search(query_embedding, k * 2)

        best_by_parent = {}
        for child_doc, score in child_results:
            parent_index = self._find_parent_index(child_doc)
            if parent_index is None:
                continue
            current = best_by_parent.get(parent_index)
            # Compare against the stored score, not the (doc, score) tuple.
            if current is None or score > current[1]:
                best_by_parent[parent_index] = (self.parent_docs[parent_index], score)

        parent_results = sorted(
            best_by_parent.values(), key=lambda pair: pair[1], reverse=True
        )
        return self._deduplicate_results(parent_results, k)

    def generate(self, query, context):
        """Build the answer prompt from ``context`` and ``query`` and call the LLM.

        Returns:
            The result of ``self._call_llm`` for the formatted prompt.
        """
        prompt = (
            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|