# -*- coding: utf-8 -*-
"""RAPTOR RAG - Recursive Abstractive Processing for Tree-Organized Retrieval."""

from .base import BaseRAG, Document
from loguru import logger


class RAPTORRAG(BaseRAG):
    """RAG variant that indexes a tree of recursively summarized document clusters.

    Documents are clustered by embedding similarity, each cluster is summarized
    by the LLM, and the resulting summaries are clustered and summarized again,
    for at most ``max_tree_depth`` summary levels. Every node of the tree (the
    input documents plus all generated summaries) is embedded and added to the
    vector store, so retrieval can match both detailed passages and high-level
    summaries.
    """

    def __init__(self, max_tree_depth=3, cluster_size=5, **kwargs):
        """Create the retriever.

        Args:
            max_tree_depth: Maximum number of summary levels built above the
                input documents.
            cluster_size: Target average number of nodes per cluster; a level
                of ``n`` nodes is split into ``n // cluster_size + 1`` clusters.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.

        Raises:
            ValueError: If ``cluster_size`` is smaller than 1. It is used as a
                divisor when choosing the number of clusters per level, so a
                bad value would otherwise only fail later inside ``build_tree``.
        """
        if cluster_size < 1:
            raise ValueError(
                "cluster_size must be >= 1, got {!r}".format(cluster_size)
            )
        super().__init__(**kwargs)
        self.max_tree_depth = max_tree_depth
        self.cluster_size = cluster_size
        # Every node indexed by the most recent build_tree call, in level order.
        self.tree_nodes = []

    def _cluster_documents(self, docs, k):
        """Group ``docs`` into at most ``k`` clusters by embedding similarity.

        Each document is embedded from its first 200 characters and the
        embeddings are partitioned with K-means (fixed ``random_state`` for
        reproducible clusters).

        Args:
            docs: Documents to cluster.
            k: Requested number of clusters.

        Returns:
            A list of non-empty lists of documents, in order of each cluster's
            first member. An empty ``docs`` yields ``[]``; when ``docs`` has at
            most ``k`` items they are returned together as a single cluster.
        """
        if not docs:
            # Returning [[]] here would make the caller summarize an empty
            # cluster and index an LLM-invented summary.
            return []
        if len(docs) <= k:
            return [docs]

        # Imported here rather than at module level (presumably so numpy and
        # scikit-learn are only required once a tree is actually built).
        import numpy as np
        from sklearn.cluster import KMeans

        embeddings = np.array(
            [self.embedding_model.embed_query(doc.page_content[:200]) for doc in docs]
        )
        # len(docs) > k is guaranteed by the guard above, so k is a valid
        # n_clusters value without further clamping.
        kmeans = KMeans(n_clusters=k, random_state=42)
        labels = kmeans.fit_predict(embeddings)

        # Group documents by label; dicts keep insertion order, so clusters come
        # out in the order their first member appears in ``docs``.
        clusters = {}
        for doc, label in zip(docs, labels):
            clusters.setdefault(label, []).append(doc)
        return list(clusters.values())

    def _summarize_cluster(self, docs, max_docs=5, max_chars=300):
        """Ask the LLM for a one-paragraph summary of a cluster's core topic.

        Args:
            docs: Documents in the cluster.
            max_docs: Only the first ``max_docs`` documents are put into the
                prompt (``None`` includes all of them).
            max_chars: Each included document is truncated to this many
                characters (``None`` disables truncation).

        Returns:
            Whatever ``self._call_llm`` returns for the prompt.
        """
        texts = "\n".join(doc.page_content[:max_chars] for doc in docs[:max_docs])
        # Prompt (Chinese): "Summarize the core topic of the following documents
        # in one paragraph. Document content: ... Core topic:"
        prompt = (
            "总结以下文档的核心主题,用一段话概括。\n\n"
            "文档内容:\n{}\n\n"
            "核心主题:".format(texts)
        )
        summary = self._call_llm(prompt)
        return summary

    def build_tree(self, documents):
        """Build the summary tree over ``documents`` and index every node.

        Each iteration clusters the current level, summarizes every cluster into
        a new ``Document`` (metadata ``level`` = iteration index starting at 0,
        ``num_children`` = cluster size) and moves up to the summaries. It stops
        after ``max_tree_depth`` levels or once a single summary remains. All
        nodes are then embedded with ``embed_documents`` and added to the
        vector store.

        Args:
            documents: Input documents; they are indexed as-is and their own
                metadata is not modified here.
        """
        self.tree_nodes = []
        if not documents:
            # Nothing to summarize; without this guard an empty input produced a
            # summary of nothing that was then stored in the vector store.
            logger.warning("RAPTOR build_tree called with no documents; nothing to index")
            return

        current_level = documents
        for depth in range(self.max_tree_depth):
            logger.info("RAPTOR building tree level {}".format(depth))
            num_clusters = len(current_level) // self.cluster_size + 1
            clusters = self._cluster_documents(current_level, num_clusters)

            next_level = []
            for cluster in clusters:
                summary = self._summarize_cluster(cluster)
                next_level.append(
                    Document(
                        page_content=summary,
                        metadata={"level": depth, "num_children": len(cluster)},
                    )
                )

            # The level just clustered is complete; its summaries become the
            # input for the next iteration.
            self.tree_nodes.extend(current_level)
            current_level = next_level
            if len(current_level) <= 1:
                # A single root summary remains: nothing left to cluster.
                break

        # Add the topmost level (the root, or the last summaries when the depth
        # limit was reached). It was not added inside the loop.
        self.tree_nodes.extend(current_level)

        texts = [doc.page_content for doc in self.tree_nodes]
        embeddings = self.embedding_model.embed_documents(texts)
        # NOTE(review): the vector store is not cleared here, so calling
        # build_tree twice presumably indexes a second copy of the nodes —
        # confirm against the BaseRAG / vector_store semantics.
        self.vector_store.add_documents(self.tree_nodes, embeddings)

    def retrieve(self, query, k=10):
        """Search the indexed tree nodes for the ``k`` nearest matches to ``query``.

        Both input documents and generated summaries are candidates. The raw
        hits are passed through ``self._deduplicate_results`` (defined outside
        this class) together with ``k``, and its result is returned.
        """
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(results, k)

    def generate(self, query, context):
        """Answer ``query`` with the LLM, using ``context`` as reference material.

        Args:
            query: The user's question.
            context: Reference text inserted verbatim into the prompt.

        Returns:
            Whatever ``self._call_llm`` returns for the prompt.
        """
        # Prompt (Chinese): "Answer the question based on the following
        # tree-organized reference documents. Reference documents: ...
        # Question: ... Please answer in detail."
        prompt = (
            "根据以下树状组织的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)