# -*- coding: utf-8 -*-
"""FLARE RAG - Forward-Looking Active REtrieval augmented generation."""
from loguru import logger

from .base import BaseRAG

# Marker the prompt asks the LLM to emit on a step that needs more information.
_RETRIEVAL_MARKER = "[需要检索]"
# Chinese full stop; used to split a partial answer into sentences.
_SENTENCE_END = "。"
# Number of hits requested from the vector store for each marked sentence.
_SUPPLEMENT_TOP_K = 3
# Maximum number of characters kept from each retrieved document.
_SUPPLEMENT_SNIPPET_CHARS = 200


class FLARERAG(BaseRAG):
    """RAG variant that retrieves extra context while the answer is generated.

    The LLM is asked to answer step by step and to tag any step that needs
    more information with a retrieval marker. Each tagged sentence is used as
    a search query, and the retrieved snippets are spliced into the answer
    before the LLM is called again with the accumulated text.
    """

    def __init__(self, max_iterations=3, **kwargs):
        """Initialise the pipeline.

        Args:
            max_iterations: Upper bound on the number of LLM calls made for
                one query.
            **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.max_iterations = max_iterations

    def _lookup_supplement(self, retrieval_query):
        """Return snippets retrieved for *retrieval_query*, newline-joined.

        The result is an empty string when the vector store yields no hits
        (or only hits with empty content).
        """
        q_embedding = self.embedding_model.embed_query(retrieval_query)
        results = self.vector_store.similarity_search(
            q_embedding, _SUPPLEMENT_TOP_K
        )
        # Hits are unpacked as 2-item pairs whose first item exposes
        # ``page_content``; presumably (document, score) - the second item
        # is not used here.
        return "\n".join(
            doc.page_content[:_SUPPLEMENT_SNIPPET_CHARS] for doc, _ in results
        )

    def _generate_with_retrieval(self, query):
        """Answer *query* iteratively, retrieving context for marked steps.

        Each round sends the prompt plus the answer accumulated so far to the
        LLM. A reply without the retrieval marker is appended as-is and ends
        the loop. Otherwise every marked sentence (with the marker stripped)
        is used as a search query and the retrieved snippets are appended to
        that sentence in parentheses.

        Args:
            query: The user's question; interpolated into the prompt.

        Returns:
            The accumulated answer text. It is empty when
            ``max_iterations`` allows no iteration at all.
        """
        prompt = "根据已知信息,逐步回答以下问题。如果某一步需要更多信息,请标记[需要检索]。\n\n问题:{}\n\n回答:".format(query)
        current_answer = ""
        for i in range(self.max_iterations):
            partial_answer = self._call_llm(prompt + current_answer)
            if _RETRIEVAL_MARKER not in partial_answer:
                # Nothing left to look up: this reply completes the answer.
                current_answer += partial_answer
                break

            new_sentences = []
            for sent in partial_answer.split(_SENTENCE_END):
                if _RETRIEVAL_MARKER in sent:
                    retrieval_query = sent.replace(_RETRIEVAL_MARKER, "").strip()
                    if retrieval_query:
                        additional_context = self._lookup_supplement(
                            retrieval_query
                        )
                        # Skip the annotation entirely when nothing was
                        # found, instead of appending empty parentheses.
                        if additional_context:
                            # NOTE(review): the marker itself stays in the
                            # sentence, so it is fed back to the LLM and can
                            # appear in the returned answer - confirm this
                            # is intended.
                            sent = sent + "(补充信息:" + additional_context + ")"
                new_sentences.append(sent)

            augmented = _SENTENCE_END.join(new_sentences)
            # split/join already restores every full stop of the reply,
            # including a trailing one; only add a terminator when the reply
            # did not end with one, otherwise the answer gets a doubled stop.
            if not augmented.endswith(_SENTENCE_END):
                augmented += _SENTENCE_END
            current_answer += augmented
            logger.info(
                "FLARE iteration {}: retrieved additional context", i + 1
            )
        return current_answer

    def retrieve(self, query, k=10):
        """Embed *query*, search the vector store and de-duplicate the hits.

        Args:
            query: Text handed to the embedding model's ``embed_query``.
            k: Passed both to the similarity search and to
                ``_deduplicate_results``.

        Returns:
            Whatever ``_deduplicate_results`` returns for the search hits.
        """
        query_embedding = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(query_embedding, k)
        return self._deduplicate_results(hits, k)

    def generate(self, query, context):
        """Generate an answer for *query* using active retrieval.

        NOTE(review): ``context`` is accepted but never used; the answer is
        built solely from the query and the snippets retrieved during
        generation. Presumably the parameter exists to match the base-class
        interface - confirm that dropping the pre-retrieved context is
        intended.
        """
        return self._generate_with_retrieval(query)