parent_document_rag.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # -*- coding: utf-8 -*-
  2. """Parent Document RAG - retrieve child chunks, return parent documents."""
  3. from .base import BaseRAG, Document
  4. from loguru import logger
  5. class ParentDocumentRAG(BaseRAG):
  6. def __init__(self, parent_chunk_size=1500, **kwargs):
  7. super().__init__(**kwargs)
  8. self.parent_chunk_size = parent_chunk_size
  9. self.parent_docs = []
  10. self.child_to_parent = {}
  11. def index_documents(self, documents):
  12. self.parent_docs = []
  13. self.child_to_parent = {}
  14. for i, doc in enumerate(documents):
  15. words = doc.page_content.split()
  16. for j in range(0, len(words), self.parent_chunk_size):
  17. parent_text = " ".join(words[j:j + self.parent_chunk_size])
  18. parent_doc = Document(
  19. page_content=parent_text,
  20. metadata=dict(doc.metadata, chunk_index=j, is_parent=True)
  21. )
  22. self.parent_docs.append(parent_doc)
  23. child_texts = []
  24. for doc in documents:
  25. words = doc.page_content.split()
  26. for j in range(0, len(words), self.parent_chunk_size):
  27. chunk_words = words[j:j + self.parent_chunk_size]
  28. for k in range(0, len(chunk_words), 512):
  29. child_text = " ".join(chunk_words[k:k + 512])
  30. child_id = len(child_texts)
  31. child_texts.append(child_text)
  32. self.child_to_parent[child_id] = child_id // 3
  33. if child_texts:
  34. embeddings = self.embedding_model.embed_documents(child_texts)
  35. child_docs = [
  36. Document(page_content=text, metadata={"is_parent": False})
  37. for text in child_texts
  38. ]
  39. self.vector_store.add_documents(child_docs, embeddings)
  40. logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_texts)))
  41. def retrieve(self, query, k=10):
  42. query_embedding = self.embedding_model.embed_query(query)
  43. child_results = self.vector_store.similarity_search(query_embedding, k * 2)
  44. parent_map = {}
  45. for child_doc, score in child_results:
  46. for i, parent_doc in enumerate(self.parent_docs):
  47. if child_doc.page_content[:50] in parent_doc.page_content:
  48. if i not in parent_map or score > parent_map[i]:
  49. parent_map[i] = (parent_doc, score)
  50. break
  51. parent_results = sorted(parent_map.values(), key=lambda x: x[1], reverse=True)
  52. return self._deduplicate_results(parent_results, k)
  53. def generate(self, query, context):
  54. prompt = (
  55. "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
  56. "参考文档:\n{}\n\n"
  57. "问题:{}\n\n"
  58. "请详细回答。".format(context, query)
  59. )
  60. return self._call_llm(prompt)