base.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # -*- coding: utf-8 -*-
  2. """Base RAG classes: RAGResult and BaseRAG."""
  3. from abc import ABC, abstractmethod
  4. import time
  5. from ..document_processor import Document
  6. from .dedup import content_dedup_key, deduplicate_ranked_results
  7. class RAGResult(object):
  8. """Result of a RAG query."""
  9. def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
  10. latency_generation=0.0, latency_total=0.0, metadata=None):
  11. self.answer = answer
  12. self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
  13. self.latency_retrieval = latency_retrieval
  14. self.latency_generation = latency_generation
  15. self.latency_total = latency_total
  16. self.metadata = metadata if metadata is not None else {}
  17. def to_dict(self):
  18. return {
  19. "answer": self.answer,
  20. "num_docs_retrieved": len(self.retrieved_docs),
  21. "latency_retrieval": round(self.latency_retrieval, 3),
  22. "latency_generation": round(self.latency_generation, 3),
  23. "latency_total": round(self.latency_total, 3),
  24. "metadata": self.metadata or {},
  25. }
  26. class BaseRAG(ABC):
  27. """Abstract base class for all RAG methods."""
  28. def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
  29. self.embedding_model = embedding_model
  30. self.vector_store = vector_store
  31. self.llm_client = llm_client
  32. self.llm_model = llm_model
  33. self.name = self.__class__.__name__
  34. def index_documents(self, documents):
  35. """Index documents using the embedding model and vector store."""
  36. texts = [doc.page_content for doc in documents]
  37. embeddings = self.embedding_model.embed_documents(texts)
  38. self.vector_store.add_documents(documents, embeddings)
  39. def _call_llm(self, prompt, system_prompt=None):
  40. """Call the LLM to generate a response."""
  41. messages = []
  42. if system_prompt:
  43. messages.append({"role": "system", "content": system_prompt})
  44. messages.append({"role": "user", "content": prompt})
  45. response = self.llm_client.chat.completions.create(
  46. model=self.llm_model,
  47. messages=messages,
  48. temperature=0.1,
  49. max_tokens=2048,
  50. )
  51. return response.choices[0].message.content
  52. def _format_context(self, docs):
  53. """Format retrieved documents into a context string."""
  54. context_parts = []
  55. for i, (doc, score) in enumerate(docs, 1):
  56. source = doc.metadata.get("source", "unknown")
  57. context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
  58. return "\n\n---\n\n".join(context_parts)
  59. def _dedup_key(self, doc):
  60. """Return a content-based key for retrieval result deduplication."""
  61. return content_dedup_key(doc)
  62. def _deduplicate_results(self, results, k=None):
  63. """Deduplicate ranked retrieval results by normalized document content."""
  64. return deduplicate_ranked_results(results, k)
  65. @abstractmethod
  66. def retrieve(self, query, k=10):
  67. """Retrieve relevant documents for the query."""
  68. pass
  69. @abstractmethod
  70. def generate(self, query, context):
  71. """Generate an answer based on the context."""
  72. pass
  73. def query(self, query, k=10):
  74. """Full RAG pipeline: retrieve + generate."""
  75. start_total = time.time()
  76. t0 = time.time()
  77. docs = self.retrieve(query, k)
  78. retrieval_time = time.time() - t0
  79. context = self._format_context(docs)
  80. t1 = time.time()
  81. answer = self.generate(query, context)
  82. generation_time = time.time() - t1
  83. total_time = time.time() - start_total
  84. return RAGResult(
  85. answer=answer,
  86. retrieved_docs=docs,
  87. latency_retrieval=retrieval_time,
  88. latency_generation=generation_time,
  89. latency_total=total_time,
  90. metadata={"method": self.name, "num_context_tokens": len(context)},
  91. )