flare_rag.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # -*- coding: utf-8 -*-
  2. """FLARE RAG - Forward-Looking Active REtrieval augmented generation."""
  3. from .base import BaseRAG
  4. from loguru import logger
  5. class FLARERAG(BaseRAG):
  6. def __init__(self, max_iterations=3, **kwargs):
  7. super().__init__(**kwargs)
  8. self.max_iterations = max_iterations
  9. def _generate_with_retrieval(self, query):
  10. prompt = "根据已知信息,逐步回答以下问题。如果某一步需要更多信息,请标记[需要检索]。\n\n问题:{}\n\n回答:".format(query)
  11. current_answer = ""
  12. for i in range(self.max_iterations):
  13. step_prompt = prompt + current_answer
  14. partial_answer = self._call_llm(step_prompt)
  15. if "[需要检索]" not in partial_answer:
  16. current_answer += partial_answer
  17. break
  18. sentences = partial_answer.split("。")
  19. new_sentences = []
  20. for sent in sentences:
  21. if "[需要检索]" in sent:
  22. retrieval_query = sent.replace("[需要检索]", "").strip()
  23. if retrieval_query:
  24. q_embedding = self.embedding_model.embed_query(retrieval_query)
  25. results = self.vector_store.similarity_search(q_embedding, 3)
  26. additional_context = "\n".join([doc.page_content[:200] for doc, _ in results])
  27. new_sentences.append(sent + "(补充信息:" + additional_context + ")")
  28. else:
  29. new_sentences.append(sent)
  30. else:
  31. new_sentences.append(sent)
  32. current_answer += "。".join(new_sentences) + "。"
  33. logger.info("FLARE iteration {}: retrieved additional context".format(i + 1))
  34. return current_answer
  35. def retrieve(self, query, k=10):
  36. query_embedding = self.embedding_model.embed_query(query)
  37. return self._deduplicate_results(self.vector_store.similarity_search(query_embedding, k), k)
  38. def generate(self, query, context):
  39. return self._generate_with_retrieval(query)