| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- # -*- coding: utf-8 -*-
- """Contextual Compression RAG - compress documents to extract relevant parts."""
- from .base import BaseRAG
- from loguru import logger
class ContextualCompressionRAG(BaseRAG):
    """RAG variant that compresses each retrieved document with the LLM.

    Retrieval over-fetches ``2 * k`` candidates from the vector store, asks the
    LLM to extract only the query-relevant part of each candidate, and passes
    the compressed documents to ``_deduplicate_results`` for the final cut.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _compress_document(self, doc, query):
        """Ask the LLM to extract the parts of ``doc`` relevant to ``query``.

        Args:
            doc: Document object exposing ``page_content``.
            query: The user question that drives the extraction.

        Returns:
            The raw result of ``self._call_llm`` for the extraction prompt
            (presumably a string — confirm against ``BaseRAG._call_llm``).
        """
        # Prompt (Chinese): "Extract the core information relevant to the
        # question from the following document and remove irrelevant content."
        # followed by the question, the document text and an answer cue.
        prompt = (
            "提取以下文档中与问题相关的核心信息,去除不相关的内容。\n\n"
            "问题:{}\n\n"
            "文档内容:\n{}\n\n"
            "提取的核心信息:".format(query, doc.page_content)
        )
        return self._call_llm(prompt)

    def retrieve(self, query, k=10):
        """Retrieve up to ``k`` compressed documents relevant to ``query``.

        Args:
            query: The user question.
            k: Number of results to return after deduplication.

        Returns:
            The result of ``self._deduplicate_results`` applied to
            ``(compressed_doc, score)`` pairs sorted by descending score.
            Candidates whose compression is empty are omitted.
        """
        query_embedding = self.embedding_model.embed_query(query)
        # Over-fetch so that dropping empty compressions and deduplication
        # still leave enough candidates to fill k slots.
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)

        compressed_results = []
        dropped = 0
        for doc, score in initial_results:
            # Assumes _call_llm returns str or None (TODO confirm in BaseRAG).
            compressed_content = (self._compress_document(doc, query) or "").strip()
            if not compressed_content:
                # Nothing relevant was extracted; a blank document would only
                # waste a top-k slot and add noise to the generation context.
                dropped += 1
                continue
            compressed_doc = type(doc)(
                page_content=compressed_content,
                # Tolerate documents whose metadata is None.
                metadata=dict(doc.metadata or {}, is_compressed=True),
            )
            compressed_results.append((compressed_doc, score))

        logger.info(
            "ContextualCompression: compressed {} documents ({} dropped as empty)",
            len(compressed_results),
            dropped,
        )

        # NOTE(review): descending order assumes higher score == more similar;
        # confirm against the vector store's similarity_search contract.
        compressed_results.sort(key=lambda pair: pair[1], reverse=True)
        return self._deduplicate_results(compressed_results, k)

    def generate(self, query, context):
        """Answer ``query`` using ``context`` built from compressed documents.

        Args:
            query: The user question.
            context: Reference text inserted verbatim into the prompt.

        Returns:
            The result of ``self._call_llm`` for the answer prompt.
        """
        # Prompt (Chinese): "Answer the question based on the following
        # contextually compressed reference documents." followed by the
        # documents, the question and "Please answer in detail."
        prompt = (
            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|