| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- # -*- coding: utf-8 -*-
- """RAPTOR RAG - Recursive Abstractive Processing for Tree-Organized Retrieval."""
- from .base import BaseRAG, Document
- from loguru import logger
class RAPTORRAG(BaseRAG):
    """RAG variant that indexes documents together with recursive cluster summaries.

    ``build_tree`` clusters the current level of documents with k-means, asks
    the LLM for a one-paragraph summary of each cluster, and repeats the
    process on the summaries for at most ``max_tree_depth`` levels. The
    original documents and every summary are stored in the same vector store,
    so ``retrieve`` searches all levels of the tree at once.

    The class uses ``self.embedding_model``, ``self.vector_store``,
    ``self._call_llm`` and ``self._deduplicate_results``; none of them are
    defined here (presumably provided by ``BaseRAG`` -- TODO confirm).
    """

    def __init__(self, max_tree_depth=3, cluster_size=5, **kwargs):
        """Configure the tree builder.

        Args:
            max_tree_depth: Maximum number of summarisation levels built on
                top of the input documents.
            cluster_size: Divisor used to derive the number of clusters per
                level (``len(level) // cluster_size + 1``). Must be >= 1.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.

        Raises:
            ValueError: If ``cluster_size`` is smaller than 1.
        """
        # Reject this up front: a zero would otherwise surface as a
        # ZeroDivisionError in build_tree, and a negative value as an
        # invalid cluster count inside KMeans.
        if cluster_size < 1:
            raise ValueError(
                "cluster_size must be >= 1, got {!r}".format(cluster_size)
            )
        super().__init__(**kwargs)
        self.max_tree_depth = max_tree_depth
        self.cluster_size = cluster_size
        # Flat list of every node (input documents + summaries) of the last
        # tree built by build_tree.
        self.tree_nodes = []

    def _cluster_documents(self, docs, k):
        """Split ``docs`` into at most ``k`` groups by embedding similarity.

        Args:
            docs: Sequence of objects exposing a ``page_content`` string.
            k: Requested number of clusters.

        Returns:
            A list of clusters (each a list of documents). When there are no
            more than ``k`` documents, a single cluster holding ``docs``
            itself is returned and nothing is embedded. Otherwise clusters
            are ordered by the first appearance of their k-means label.
        """
        if len(docs) <= k:
            return [docs]

        # Heavy optional dependencies are imported lazily, and before any
        # embedding call so a missing package fails without wasted work.
        import numpy as np
        from sklearn.cluster import KMeans

        # Only the first 200 characters of each document are embedded.
        # NOTE(review): this issues one embed_query call per document;
        # embed_documents (used in build_tree) might batch this -- confirm the
        # model embeds queries and documents identically before switching.
        vectors = np.array(
            [
                self.embedding_model.embed_query(doc.page_content[:200])
                for doc in docs
            ]
        )

        # len(docs) > k is guaranteed by the guard above, so k needs no
        # further clamping. The fixed seed pins the k-means initialisation.
        labels = KMeans(n_clusters=k, random_state=42).fit_predict(vectors)

        clusters = {}
        for doc, label in zip(docs, labels):
            clusters.setdefault(label, []).append(doc)
        return list(clusters.values())

    def _summarize_cluster(self, docs):
        """Ask the LLM for a one-paragraph summary of a cluster.

        Only the first 5 documents, truncated to 300 characters each, are put
        into the prompt; any further cluster members do not influence the
        summary.

        Args:
            docs: Sequence of objects exposing a ``page_content`` string.

        Returns:
            Whatever ``self._call_llm`` returns for the summary prompt.
        """
        texts = "\n".join(doc.page_content[:300] for doc in docs[:5])
        prompt = (
            "总结以下文档的核心主题,用一段话概括。\n\n"
            "文档内容:\n{}\n\n"
            "核心主题:".format(texts)
        )
        return self._call_llm(prompt)

    def build_tree(self, documents):
        """Build the summary tree for ``documents`` and index every node.

        Resets ``self.tree_nodes``, then for each level clusters the current
        nodes, summarises each cluster into a new ``Document`` (metadata keys
        ``level`` and ``num_children``) and moves up to the summaries. It
        stops after ``max_tree_depth`` levels or as soon as a level has at
        most one node. Finally all nodes are embedded and added to
        ``self.vector_store``.

        Args:
            documents: Sized sequence of leaf documents exposing
                ``page_content``. If empty, nothing is summarised or indexed.
        """
        self.tree_nodes = []

        # Without this guard an empty input becomes one empty cluster, which
        # triggers an LLM call on no text and indexes the resulting summary.
        if len(documents) == 0:
            logger.warning("RAPTOR build_tree called with no documents")
            return

        current_level = documents

        for depth in range(self.max_tree_depth):
            logger.info("RAPTOR building tree level {}".format(depth))

            n_clusters = len(current_level) // self.cluster_size + 1
            clusters = self._cluster_documents(current_level, n_clusters)

            next_level = [
                Document(
                    page_content=self._summarize_cluster(cluster),
                    metadata={"level": depth, "num_children": len(cluster)},
                )
                for cluster in clusters
            ]

            # The level just summarised is complete; record it and move up.
            self.tree_nodes.extend(current_level)
            current_level = next_level

            if len(current_level) <= 1:
                break

        # The topmost level never went through the loop body's extend.
        self.tree_nodes.extend(current_level)

        texts = [doc.page_content for doc in self.tree_nodes]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(self.tree_nodes, embeddings)

    def retrieve(self, query, k=10):
        """Search the vector store for ``query`` and deduplicate the hits.

        Args:
            query: Query text, embedded with ``embedding_model.embed_query``.
            k: Passed both to ``vector_store.similarity_search`` and to
                ``_deduplicate_results``.

        Returns:
            The result of ``self._deduplicate_results`` on the search hits.
        """
        query_embedding = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(hits, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, grounded on ``context``.

        Args:
            query: The user question inserted into the prompt.
            context: Reference text inserted into the prompt as-is.

        Returns:
            Whatever ``self._call_llm`` returns for the answer prompt.
        """
        prompt = (
            "根据以下树状组织的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|