# -*- coding: utf-8 -*-
"""Contextual Compression RAG - compress documents to extract relevant parts."""

from .base import BaseRAG
from loguru import logger


class ContextualCompressionRAG(BaseRAG):
    """RAG variant that uses the LLM to compress each retrieved document.

    Retrieval over-fetches candidates from the vector store. The LLM then
    extracts only the query-relevant content of each candidate. The compressed
    documents, marked with ``is_compressed=True`` in their metadata, are passed
    through the base class's de-duplication and returned as generation context.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG.__init__``."""
        super().__init__(**kwargs)

    def _compress_document(self, doc, query):
        """Ask the LLM to extract the parts of ``doc`` relevant to ``query``.

        Args:
            doc: Document object exposing ``page_content``.
            query: The user question the extraction should focus on.

        Returns:
            The LLM's extraction with surrounding whitespace stripped. Returns
            an empty string when the LLM replies with ``None`` or only
            whitespace, so callers can tell "nothing relevant" apart from
            real content.
        """
        # The adjacent literals are concatenated at parse time, so ``.format``
        # fills both placeholders in the combined prompt.
        prompt = (
            "提取以下文档中与问题相关的核心信息,去除不相关的内容。\n\n"
            "问题:{}\n\n"
            "文档内容:\n{}\n\n"
            "提取的核心信息:".format(query, doc.page_content)
        )
        compressed_content = self._call_llm(prompt)
        # Guard against a None / blank LLM reply, which would otherwise become
        # a document whose page_content is None or empty.
        return (compressed_content or "").strip()

    def retrieve(self, query, k=10, fetch_k=None):
        """Retrieve candidates, compress each one with the LLM, return the top ``k``.

        Args:
            query: The user question.
            k: Number of results to hand to ``_deduplicate_results``.
            fetch_k: Number of candidates to pull from the vector store before
                compression. Defaults to ``k * 2`` (the original behaviour).

        Returns:
            The result of ``self._deduplicate_results`` applied to the
            ``(compressed_doc, score)`` pairs sorted by descending score.
            Candidates whose compression is empty are dropped.
        """
        if fetch_k is None:
            fetch_k = k * 2

        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, fetch_k)

        compressed_results = []
        dropped = 0
        for doc, score in initial_results:
            compressed_content = self._compress_document(doc, query)
            if not compressed_content:
                # The LLM found nothing relevant. An empty document would only
                # occupy a result slot, so skip it.
                dropped += 1
                continue
            # Rebuild with the same document class so downstream code gets the
            # type it expects; the original metadata is copied, not mutated.
            compressed_doc = type(doc)(
                page_content=compressed_content,
                metadata=dict(doc.metadata, is_compressed=True),
            )
            compressed_results.append((compressed_doc, score))

        logger.info("ContextualCompression: compressed {} documents", len(compressed_results))
        if dropped:
            logger.debug(
                "ContextualCompression: dropped {} documents with no relevant content",
                dropped,
            )

        # NOTE(review): sorting descending assumes a higher score means more
        # similar for this vector store -- confirm against its implementation.
        compressed_results.sort(key=lambda item: item[1], reverse=True)
        return self._deduplicate_results(compressed_results, k)

    def generate(self, query, context):
        """Answer ``query`` using ``context`` built from compressed documents.

        Args:
            query: The user question.
            context: Reference material inserted verbatim into the prompt
                (presumably a pre-formatted string -- confirm against callers).

        Returns:
            The LLM's answer as returned by ``self._call_llm``.
        """
        prompt = (
            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)