# -*- coding: utf-8 -*-
"""Base RAG classes: RAGResult and BaseRAG."""

from abc import ABC, abstractmethod
import time

from ..document_processor import Document
from .dedup import content_dedup_key, deduplicate_ranked_results


class RAGResult(object):
    """Result of a RAG query.

    Attributes:
        answer: Generated answer text.
        retrieved_docs: Retrieval results as returned by ``BaseRAG.retrieve``
            (``BaseRAG._format_context`` consumes them as ``(doc, score)``
            pairs).
        latency_retrieval: Seconds spent retrieving.
        latency_generation: Seconds spent generating.
        latency_total: Seconds spent in the whole pipeline.
        metadata: Free-form dict of extra information about the run.
    """

    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
                 latency_generation=0.0, latency_total=0.0, metadata=None):
        self.answer = answer
        # ``None`` sentinels avoid sharing one mutable default between instances.
        self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
        self.latency_retrieval = latency_retrieval
        self.latency_generation = latency_generation
        self.latency_total = latency_total
        self.metadata = metadata if metadata is not None else {}

    def to_dict(self):
        """Return a summary dict of this result.

        Only the number of retrieved documents is included, not the documents
        themselves. Latencies are rounded to three decimal places.
        """
        return {
            "answer": self.answer,
            "num_docs_retrieved": len(self.retrieved_docs),
            "latency_retrieval": round(self.latency_retrieval, 3),
            "latency_generation": round(self.latency_generation, 3),
            "latency_total": round(self.latency_total, 3),
            "metadata": self.metadata or {},
        }


class BaseRAG(ABC):
    """Abstract base class for all RAG methods.

    Subclasses implement ``retrieve`` and ``generate``; ``query`` wires them
    together and times each stage.
    """

    def __init__(self, embedding_model=None, vector_store=None, llm_client=None,
                 llm_model="gpt-4o", **kwargs):
        """Store the pipeline components.

        Args:
            embedding_model: Object exposing ``embed_documents(texts)``.
            vector_store: Object exposing ``add_documents(documents, embeddings)``.
            llm_client: Chat client used by ``_call_llm`` via
                ``client.chat.completions.create(...)``.
            llm_model: Model name passed to the chat client.
            **kwargs: Accepted for subclass/factory compatibility; ignored here.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__

    def index_documents(self, documents):
        """Index documents using the embedding model and vector store.

        Args:
            documents: Iterable of document objects exposing ``page_content``.
                It is materialized into a list first, so one-shot iterables
                (e.g. generators) are not exhausted before reaching the store.
        """
        documents = list(documents)
        if not documents:
            # Nothing to index; skip the embedding and store round-trips.
            return
        texts = [doc.page_content for doc in documents]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(documents, embeddings)

    def _call_llm(self, prompt, system_prompt=None):
        """Call the LLM to generate a response.

        Args:
            prompt: User message content.
            system_prompt: Optional system message; omitted when falsy.

        Returns:
            The ``content`` of the first choice's message.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = self.llm_client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.1,
            max_tokens=2048,
        )
        return response.choices[0].message.content

    def _format_context(self, docs):
        """Format retrieved documents into a numbered context string.

        Args:
            docs: Iterable of ``(doc, score)`` pairs. Each ``doc`` exposes
                ``metadata`` (a dict) and ``page_content``; ``score`` must
                support ``{:.3f}`` formatting.

        Returns:
            Entries numbered from 1, separated by ``---`` rules.
        """
        context_parts = []
        for i, (doc, score) in enumerate(docs, 1):
            source = doc.metadata.get("source", "unknown")
            context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(
                i, score, source, doc.page_content))
        return "\n\n---\n\n".join(context_parts)

    def _dedup_key(self, doc):
        """Return a content-based key for retrieval result deduplication."""
        return content_dedup_key(doc)

    def _deduplicate_results(self, results, k=None):
        """Deduplicate ranked retrieval results by normalized document content.

        Args:
            results: Ranked retrieval results.
            k: Optional cap forwarded to ``deduplicate_ranked_results``.
        """
        # NOTE(review): this delegates straight to the module helper and does
        # not consult ``self._dedup_key``, so overriding ``_dedup_key`` in a
        # subclass does not change this method's behavior. Confirm whether the
        # helper can accept a key function.
        return deduplicate_ranked_results(results, k)

    @abstractmethod
    def retrieve(self, query, k=10):
        """Retrieve relevant documents for the query."""
        pass

    @abstractmethod
    def generate(self, query, context):
        """Generate an answer based on the context."""
        pass

    def query(self, query, k=10):
        """Full RAG pipeline: retrieve + generate.

        Args:
            query: User question.
            k: Number of documents requested from ``retrieve``.

        Returns:
            A ``RAGResult`` with the answer, retrieved docs, and per-stage
            latencies in seconds.
        """
        # perf_counter is monotonic and high-resolution, so durations are not
        # distorted by wall-clock adjustments the way time.time() can be.
        start_total = time.perf_counter()

        t0 = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - t0

        context = self._format_context(docs)

        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1

        total_time = time.perf_counter() - start_total

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # Despite its key name, "num_context_tokens" holds the character
            # length of the formatted context string, not a token count.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )