naive_rag.py 931 B

12345678910111213141516171819202122
  1. # -*- coding: utf-8 -*-
  2. """Naive RAG - basic semantic retrieval."""
  3. from .base import BaseRAG
  class NaiveRAG(BaseRAG):
      """Baseline ("naive") RAG pipeline: one embedding lookup, then one LLM call.

      ``retrieve`` embeds the query and asks the vector store for similar
      entries; ``generate`` fills a prompt template with the supplied context
      and question and hands the result to the LLM. ``embedding_model``,
      ``vector_store``, ``_deduplicate_results`` and ``_call_llm`` are not
      defined in this class -- presumably BaseRAG provides them (confirm in
      ``.base``).
      """

      # Built-in prompt (Chinese). English gloss:
      #   "Answer the question based on the following reference documents.
      #    Reference documents: {context}
      #    Question: {query}
      #    Please answer in detail; if the reference documents contain no
      #    relevant information, say that the answer cannot be found in them."
      _DEFAULT_RETRIEVAL_PROMPT_TEMPLATE = (
          "根据以下参考文档,回答问题。\n\n"
          "参考文档:\n{context}\n\n"
          "问题:{query}\n\n"
          "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
      )

      def __init__(self, retrieval_prompt_template=None, **kwargs):
          """Initialise the base pipeline, then choose the prompt template.

          Args:
              retrieval_prompt_template: ``str.format`` template, expected to
                  contain ``{context}`` and ``{query}`` placeholders. Any falsy
                  value (``None`` or ``""``) selects the built-in template.
              **kwargs: Forwarded unchanged to ``BaseRAG.__init__``.
          """
          super().__init__(**kwargs)
          # Falsy test (not ``is None``) on purpose: an empty string also falls
          # back to the built-in template.
          if not retrieval_prompt_template:
              retrieval_prompt_template = self._DEFAULT_RETRIEVAL_PROMPT_TEMPLATE
          self.retrieval_prompt_template = retrieval_prompt_template

      def retrieve(self, query, k=10):
          """Fetch candidate entries for ``query`` from the vector store.

          The query is embedded with ``self.embedding_model.embed_query``, the
          vector store is searched with that embedding and ``k``, and the raw
          hits are passed through ``self._deduplicate_results`` together with
          ``k``.

          Returns:
              Whatever ``self._deduplicate_results`` returns -- presumably a
              list of at most ``k`` unique hits (confirm in BaseRAG).
          """
          query_vector = self.embedding_model.embed_query(query)
          raw_hits = self.vector_store.similarity_search(query_vector, k)
          # k is handed over again, presumably so the helper can trim back to
          # k after duplicates are dropped -- TODO confirm in BaseRAG.
          return self._deduplicate_results(raw_hits, k)

      def generate(self, query, context):
          """Build the prompt from ``context`` and ``query`` and ask the LLM.

          ``context`` is substituted into the template as-is (non-string values
          get their default ``format`` rendering), so callers presumably pass
          already-joined text. With a custom template, unknown named
          placeholders raise ``KeyError`` and literal braces must be doubled
          (``{{`` / ``}}``).

          Returns:
              Whatever ``self._call_llm`` returns for the filled-in prompt.
          """
          prompt_text = self.retrieval_prompt_template.format(
              query=query,
              context=context,
          )
          return self._call_llm(prompt_text)