# -*- coding: utf-8 -*-
"""Parent Document RAG - retrieve child chunks, return parent documents."""
from .base import BaseRAG, Document
from loguru import logger


class ParentDocumentRAG(BaseRAG):
    """Search over small child chunks, answer with their larger parent chunks.

    Each input document is split on whitespace into parent chunks of
    ``parent_chunk_size`` words; every parent chunk is split further into
    child chunks of ``child_chunk_size`` words. Only the child chunks are
    embedded and added to ``self.vector_store``. At query time the matching
    children are mapped back to their parent chunks, which are returned.
    """

    def __init__(self, parent_chunk_size=1500, child_chunk_size=512, **kwargs):
        """Create the retriever.

        Args:
            parent_chunk_size: Number of words per parent chunk.
            child_chunk_size: Number of words per child chunk (the unit that is
                embedded and searched). Defaults to 512.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.

        Raises:
            ValueError: If either chunk size is not positive (a zero step would
                crash ``range`` and a negative one would index nothing).
        """
        if parent_chunk_size <= 0 or child_chunk_size <= 0:
            raise ValueError(
                "chunk sizes must be positive, got parent_chunk_size={}, "
                "child_chunk_size={}".format(parent_chunk_size, child_chunk_size)
            )
        super().__init__(**kwargs)
        self.parent_chunk_size = parent_chunk_size
        self.child_chunk_size = child_chunk_size
        self.parent_docs = []  # parent chunk Documents, in indexing order
        self.child_to_parent = {}  # child id -> index into self.parent_docs

    def index_documents(self, documents):
        """Split ``documents`` into parent/child chunks and index the children.

        Rebuilds ``self.parent_docs`` and ``self.child_to_parent`` from scratch,
        embeds every child chunk with ``self.embedding_model.embed_documents``
        and adds the child Documents to ``self.vector_store``.

        ``documents`` is consumed in a single pass, so a generator is fine.

        NOTE(review): ``self.vector_store`` is not cleared here, so children
        added by an earlier call may still come back from searches;
        ``_parent_index_for`` rejects a candidate parent that does not
        contain the child's text.

        Args:
            documents: Iterable of ``Document`` objects exposing
                ``page_content`` and ``metadata``.
        """
        self.parent_docs = []
        self.child_to_parent = {}
        child_docs = []

        for doc in documents:
            words = doc.page_content.split()
            for start in range(0, len(words), self.parent_chunk_size):
                parent_words = words[start:start + self.parent_chunk_size]
                parent_index = len(self.parent_docs)
                self.parent_docs.append(Document(
                    page_content=" ".join(parent_words),
                    metadata=dict(doc.metadata, chunk_index=start, is_parent=True),
                ))
                for child_start in range(0, len(parent_words), self.child_chunk_size):
                    child_id = len(child_docs)
                    # Record the actual owning parent; parents at the end of a
                    # document can have fewer children than full-size parents.
                    self.child_to_parent[child_id] = parent_index
                    child_words = parent_words[child_start:child_start + self.child_chunk_size]
                    child_docs.append(Document(
                        page_content=" ".join(child_words),
                        metadata={
                            "is_parent": False,
                            "parent_index": parent_index,
                            "child_id": child_id,
                        },
                    ))

        if child_docs:
            embeddings = self.embedding_model.embed_documents(
                [child.page_content for child in child_docs]
            )
            self.vector_store.add_documents(child_docs, embeddings)

        logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_docs)))

    def _parent_index_for(self, child_doc):
        """Return the index in ``self.parent_docs`` of ``child_doc``'s parent, or None.

        Uses the ``parent_index`` stored in the child's metadata by
        ``index_documents`` when present and consistent. Otherwise it scans
        for the first parent whose text contains the child's text, in case the
        vector store does not round-trip metadata. A child chunk is always a
        contiguous substring of its parent, since both are built with
        ``" ".join``.
        """
        text = child_doc.page_content
        metadata = getattr(child_doc, "metadata", None) or {}
        index = metadata.get("parent_index")
        if isinstance(index, int) and 0 <= index < len(self.parent_docs):
            if text in self.parent_docs[index].page_content:
                return index
        for i, parent in enumerate(self.parent_docs):
            if text in parent.page_content:
                return i
        return None

    def retrieve(self, query, k=10):
        """Return up to ``k`` parent chunks whose children best match ``query``.

        Fetches ``2 * k`` child chunks and maps each one to its parent. Each
        parent keeps its highest child score, and parents are sorted by that
        score in descending order. The list is then passed to
        ``self._deduplicate_results``.

        Args:
            query: Query text.
            k: Maximum number of parent results.

        Returns:
            The result of ``self._deduplicate_results(parent_results, k)``, where
            ``parent_results`` is a list of ``(parent_doc, score)`` pairs.
        """
        query_embedding = self.embedding_model.embed_query(query)
        # Over-fetch children because several of them may share one parent.
        child_results = self.vector_store.similarity_search(query_embedding, k * 2)

        best = {}  # parent index -> (parent_doc, best child score)
        for child_doc, score in child_results:
            parent_index = self._parent_index_for(child_doc)
            if parent_index is None:
                continue
            # Higher scores are treated as better, matching the descending sort
            # below (TODO confirm against vector_store.similarity_search).
            # Compare with the stored score, not the (doc, score) tuple.
            if parent_index not in best or score > best[parent_index][1]:
                best[parent_index] = (self.parent_docs[parent_index], score)

        parent_results = sorted(best.values(), key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(parent_results, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, grounded in the retrieved ``context``.

        The (Chinese) prompt presents ``context`` as reference documents from
        parent-child retrieval that include full context. It then asks the
        model to answer ``query`` in detail.
        """
        prompt = (
            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)