contextual_compression_rag.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # -*- coding: utf-8 -*-
  2. """Contextual Compression RAG - compress documents to extract relevant parts."""
  3. from .base import BaseRAG
  4. from loguru import logger
  5. class ContextualCompressionRAG(BaseRAG):
  6. def __init__(self, **kwargs):
  7. super().__init__(**kwargs)
  8. def _compress_document(self, doc, query):
  9. prompt = (
  10. "提取以下文档中与问题相关的核心信息,去除不相关的内容。\n\n"
  11. "问题:{}\n\n"
  12. "文档内容:\n{}\n\n"
  13. "提取的核心信息:".format(query, doc.page_content)
  14. )
  15. compressed_content = self._call_llm(prompt)
  16. return compressed_content
  17. def retrieve(self, query, k=10):
  18. query_embedding = self.embedding_model.embed_query(query)
  19. initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
  20. compressed_results = []
  21. for doc, score in initial_results:
  22. compressed_content = self._compress_document(doc, query)
  23. compressed_doc = type(doc)(
  24. page_content=compressed_content,
  25. metadata=dict(doc.metadata, is_compressed=True)
  26. )
  27. compressed_results.append((compressed_doc, score))
  28. logger.info("ContextualCompression: compressed {} documents".format(len(compressed_results)))
  29. compressed_results.sort(key=lambda x: x[1], reverse=True)
  30. return self._deduplicate_results(compressed_results, k)
  31. def generate(self, query, context):
  32. prompt = (
  33. "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
  34. "参考文档:\n{}\n\n"
  35. "问题:{}\n\n"
  36. "请详细回答。".format(context, query)
  37. )
  38. return self._call_llm(prompt)