| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- # -*- coding: utf-8 -*-
- """FLARE RAG - Forward-Looking Active REtrieval augmented generation."""
- from .base import BaseRAG
- from loguru import logger
class FLARERAG(BaseRAG):
    """RAG variant that retrieves extra context while the answer is generated.

    The LLM is asked to answer step by step and to tag any step that needs
    more information with ``[需要检索]``. Each tagged sentence is used as a
    retrieval query and the retrieved snippets are appended to that sentence
    before the next generation round.

    Relies on ``self._call_llm``, ``self.embedding_model``,
    ``self.vector_store`` and ``self._deduplicate_results``, which are not
    defined in this class (presumably provided by ``BaseRAG`` - confirm).
    """

    # Marker the prompt asks the LLM to emit when a step needs more information.
    _RETRIEVAL_MARKER = "[需要检索]"
    # Full stop used to split a partial answer into sentences.
    _SENTENCE_END = "。"
    # Number of documents requested for each flagged sentence.
    _RETRIEVAL_TOP_K = 3
    # Maximum number of characters kept from each retrieved document.
    _SNIPPET_CHARS = 200

    def __init__(self, max_iterations: int = 3, **kwargs):
        """Initialise the pipeline.

        Args:
            max_iterations: Upper bound on generate/retrieve rounds per query.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.max_iterations = max_iterations

    def _augment_sentence(self, sent: str) -> str:
        """Append retrieved snippets to ``sent`` if it carries the marker.

        The sentence is returned unchanged when it has no marker, when nothing
        but the marker is left to use as a query, or when the search yields no
        results (so an empty "(补充信息:)" is never emitted).
        """
        marker = self._RETRIEVAL_MARKER
        if marker not in sent:
            return sent

        retrieval_query = sent.replace(marker, "").strip()
        if not retrieval_query:
            return sent

        q_embedding = self.embedding_model.embed_query(retrieval_query)
        results = self.vector_store.similarity_search(q_embedding, self._RETRIEVAL_TOP_K)
        # Each result is unpacked as a (document, _) pair; only the document
        # text is used, truncated to keep the prompt short.
        snippets = [doc.page_content[:self._SNIPPET_CHARS] for doc, _ in results]
        if not snippets:
            return sent

        # NOTE(review): the marker itself is kept in the sentence, so it also
        # appears in the final answer - confirm whether it should be stripped.
        return sent + "(补充信息:" + "\n".join(snippets) + ")"

    def _generate_with_retrieval(self, query: str) -> str:
        """Generate an answer, retrieving extra context for flagged steps.

        Runs at most ``self.max_iterations`` rounds. Each round sends the
        prompt plus the answer accumulated so far to the LLM. A round whose
        output has no retrieval marker is appended verbatim and ends the loop;
        otherwise every flagged sentence is augmented and the loop continues.

        Args:
            query: The user question.

        Returns:
            The accumulated answer text (empty if ``max_iterations`` <= 0).
        """
        prompt = "根据已知信息,逐步回答以下问题。如果某一步需要更多信息,请标记[需要检索]。\n\n问题:{}\n\n回答:".format(query)
        stop = self._SENTENCE_END

        current_answer = ""
        for i in range(self.max_iterations):
            partial_answer = self._call_llm(prompt + current_answer)

            if self._RETRIEVAL_MARKER not in partial_answer:
                current_answer += partial_answer
                break

            augmented = [self._augment_sentence(sent) for sent in partial_answer.split(stop)]
            chunk = stop.join(augmented)
            # split() leaves a trailing "" when the text already ends with a
            # full stop, so join() restores it; only add one if it is missing
            # to avoid emitting a doubled terminator.
            if not chunk.endswith(stop):
                chunk += stop
            current_answer += chunk
            logger.info("FLARE iteration {}: retrieved additional context", i + 1)

        return current_answer

    def retrieve(self, query: str, k: int = 10):
        """Embed ``query``, search the vector store for ``k`` hits and return
        them after passing through ``self._deduplicate_results(results, k)``.
        """
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(results, k)

    def generate(self, query: str, context) -> str:
        """Answer ``query`` using the iterative retrieval loop.

        NOTE(review): ``context`` is accepted but not used - the loop performs
        its own retrieval. Presumably the parameter exists to match the base
        class interface; confirm that ignoring it is intended.
        """
        return self._generate_with_retrieval(query)
|