| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- # -*- coding: utf-8 -*-
- """Base RAG classes: RAGResult and BaseRAG."""
- from abc import ABC, abstractmethod
- import time
- from ..document_processor import Document
- from .dedup import content_dedup_key, deduplicate_ranked_results
class RAGResult(object):
    """Result of a RAG query.

    Attributes:
        answer: Generated answer text (defaults to ``""``).
        retrieved_docs: Retrieved results, stored as given (defaults to a
            fresh empty list).
        latency_retrieval: Retrieval latency in seconds.
        latency_generation: Generation latency in seconds.
        latency_total: End-to-end latency in seconds.
        metadata: Free-form dict of extra information (defaults to a fresh
            empty dict).
    """

    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
                 latency_generation=0.0, latency_total=0.0, metadata=None):
        self.answer = answer
        # None sentinels ensure each instance gets its own list/dict instead
        # of sharing one mutable default.
        self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
        self.latency_retrieval = latency_retrieval
        self.latency_generation = latency_generation
        self.latency_total = latency_total
        self.metadata = metadata if metadata is not None else {}

    def __repr__(self):
        # Show only the document count: the documents themselves can be large.
        return "{}(answer={!r}, num_docs_retrieved={}, latency_total={!r}, metadata={!r})".format(
            type(self).__name__,
            self.answer,
            len(self.retrieved_docs or ()),
            self.latency_total,
            self.metadata,
        )

    def to_dict(self):
        """Return a plain-dict summary of this result.

        The retrieved documents are omitted; only their count is reported.
        Latencies are rounded to 3 decimal places. ``metadata`` is a shallow
        copy, so mutating the returned dict does not alter this result.
        """
        return {
            "answer": self.answer,
            "num_docs_retrieved": len(self.retrieved_docs),
            "latency_retrieval": round(self.latency_retrieval, 3),
            "latency_generation": round(self.latency_generation, 3),
            "latency_total": round(self.latency_total, 3),
            "metadata": dict(self.metadata or {}),
        }
class BaseRAG(ABC):
    """Abstract base class for all RAG methods.

    Subclasses implement :meth:`retrieve` and :meth:`generate`. This base class
    provides document indexing, the LLM call, context formatting, result
    deduplication helpers, and the timed end-to-end :meth:`query` pipeline.
    """

    def __init__(self, embedding_model=None, vector_store=None, llm_client=None,
                 llm_model="gpt-4o", **kwargs):
        """Store the collaborators used by the pipeline.

        Args:
            embedding_model: Object exposing ``embed_documents(texts)``; used
                by :meth:`index_documents`.
            vector_store: Object exposing ``add_documents(documents, embeddings)``;
                used by :meth:`index_documents`.
            llm_client: Client exposing ``chat.completions.create`` (OpenAI-style);
                used by :meth:`_call_llm`.
            llm_model: Model name passed to the LLM client.
            **kwargs: Accepted and ignored by this base class (presumably so
                subclass options can be passed uniformly — confirm with callers).
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__

    def index_documents(self, documents):
        """Embed documents and add them to the vector store.

        Args:
            documents: Iterable of documents exposing ``page_content``. Any
                iterable is accepted, including a one-shot generator.
        """
        # Materialize once: the documents are consumed twice (text extraction,
        # then add_documents), which would otherwise hand an exhausted
        # generator to the vector store.
        documents = list(documents)
        if not documents:
            # Nothing to index; some embedding backends reject empty batches.
            return
        texts = [doc.page_content for doc in documents]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(documents, embeddings)

    def _call_llm(self, prompt, system_prompt=None, temperature=0.1, max_tokens=2048):
        """Send a single-turn chat request and return the reply text.

        Args:
            prompt: Content of the user message.
            system_prompt: Optional system message, prepended only when truthy.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: ``max_tokens`` value forwarded to the API.

        Returns:
            The content of the first choice, or ``""`` when the API returned
            ``None`` content.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = self.llm_client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        content = response.choices[0].message.content
        # OpenAI-style APIs type message content as optional (e.g. refusals or
        # tool calls); normalize to "" so callers always get a string.
        return content if content is not None else ""

    def _format_context(self, docs):
        """Format retrieved results into a single context string.

        Args:
            docs: Iterable of ``(doc, score)`` pairs. Each doc must expose
                ``page_content`` and a ``metadata`` mapping; ``score`` must be
                numeric.

        Returns:
            Entries numbered from 1, each headed by its score (3 decimals) and
            ``metadata["source"]`` (``"unknown"`` if absent), joined by
            ``---`` separators. An empty input yields ``""``.
        """
        context_parts = []
        for i, (doc, score) in enumerate(docs, 1):
            source = doc.metadata.get("source", "unknown")
            context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
        return "\n\n---\n\n".join(context_parts)

    def _dedup_key(self, doc):
        """Return a content-based key for retrieval result deduplication.

        Delegates to ``content_dedup_key`` from the ``.dedup`` module.
        """
        return content_dedup_key(doc)

    def _deduplicate_results(self, results, k=None):
        """Deduplicate ranked retrieval results by normalized document content.

        Delegates to ``deduplicate_ranked_results`` from the ``.dedup`` module;
        ``k`` is passed through (presumably a cap on the number of results
        kept — confirm in ``.dedup``).
        """
        return deduplicate_ranked_results(results, k)

    @abstractmethod
    def retrieve(self, query, k=10):
        """Retrieve relevant documents for the query.

        Implementations should return ``(doc, score)`` pairs, which
        :meth:`query` passes to :meth:`_format_context`.
        """

    @abstractmethod
    def generate(self, query, context):
        """Generate an answer based on the formatted context string."""

    def query(self, query, k=10):
        """Run the full RAG pipeline: retrieve, format context, then generate.

        Latencies are measured with ``time.perf_counter`` (monotonic), so they
        are unaffected by system clock adjustments.

        Args:
            query: The user question.
            k: Number of results requested from :meth:`retrieve`.

        Returns:
            RAGResult holding the answer, the retriever's output, per-stage and
            total latencies in seconds, and metadata.
        """
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total
        context = self._format_context(docs)
        start_generation = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - start_generation
        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # NOTE: despite its name, "num_context_tokens" holds the character
            # length of the formatted context; the key is kept as-is for
            # compatibility with existing consumers.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|