raptor_rag.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # -*- coding: utf-8 -*-
  2. """RAPTOR RAG - Recursive Abstractive Processing for Tree-Organized Retrieval."""
  3. from .base import BaseRAG, Document
  4. from loguru import logger
  5. class RAPTORRAG(BaseRAG):
  6. def __init__(self, max_tree_depth=3, cluster_size=5, **kwargs):
  7. super().__init__(**kwargs)
  8. self.max_tree_depth = max_tree_depth
  9. self.cluster_size = cluster_size
  10. self.tree_nodes = []
  11. def _cluster_documents(self, docs, k):
  12. if len(docs) <= k:
  13. return [docs]
  14. embeddings = []
  15. for doc in docs:
  16. emb = self.embedding_model.embed_query(doc.page_content[:200])
  17. embeddings.append(emb)
  18. import numpy as np
  19. embeddings = np.array(embeddings)
  20. from sklearn.cluster import KMeans
  21. k = min(k, len(docs))
  22. kmeans = KMeans(n_clusters=k, random_state=42)
  23. labels = kmeans.fit_predict(embeddings)
  24. clusters = {}
  25. for i, label in enumerate(labels):
  26. if label not in clusters:
  27. clusters[label] = []
  28. clusters[label].append(docs[i])
  29. return list(clusters.values())
  30. def _summarize_cluster(self, docs):
  31. texts = "\n".join([doc.page_content[:300] for doc in docs[:5]])
  32. prompt = (
  33. "总结以下文档的核心主题,用一段话概括。\n\n"
  34. "文档内容:\n{}\n\n"
  35. "核心主题:".format(texts)
  36. )
  37. summary = self._call_llm(prompt)
  38. return summary
  39. def build_tree(self, documents):
  40. self.tree_nodes = []
  41. current_level = documents
  42. for depth in range(self.max_tree_depth):
  43. logger.info("RAPTOR building tree level {}".format(depth))
  44. clusters = self._cluster_documents(current_level, len(current_level) // self.cluster_size + 1)
  45. next_level = []
  46. for cluster in clusters:
  47. summary = self._summarize_cluster(cluster)
  48. summary_doc = Document(
  49. page_content=summary,
  50. metadata={"level": depth, "num_children": len(cluster)}
  51. )
  52. next_level.append(summary_doc)
  53. self.tree_nodes.extend(current_level)
  54. current_level = next_level
  55. if len(current_level) <= 1:
  56. break
  57. self.tree_nodes.extend(current_level)
  58. texts = [doc.page_content for doc in self.tree_nodes]
  59. embeddings = self.embedding_model.embed_documents(texts)
  60. self.vector_store.add_documents(self.tree_nodes, embeddings)
  61. def retrieve(self, query, k=10):
  62. query_embedding = self.embedding_model.embed_query(query)
  63. return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
  64. def generate(self, query, context):
  65. prompt = (
  66. "根据以下树状组织的参考文档,回答问题。\n\n"
  67. "参考文档:\n{}\n\n"
  68. "问题:{}\n\n"
  69. "请详细回答。".format(context, query)
  70. )
  71. return self._call_llm(prompt)