from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
import time
import re
import numpy as np
from loguru import logger
from .document_processor import Document
from .embedding_models import BaseEmbedding
from .vector_stores import BaseVectorStore
class RAGResult(object):
    """Outcome of a single RAG query: answer text, retrieved documents and timings.

    ``retrieved_docs`` and ``metadata`` default to a fresh empty list / dict
    per instance when ``None`` is passed.
    """

    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
                 latency_generation=0.0, latency_total=0.0, metadata=None):
        if retrieved_docs is None:
            retrieved_docs = []
        if metadata is None:
            metadata = {}
        self.answer = answer
        self.retrieved_docs = retrieved_docs
        self.latency_retrieval = latency_retrieval
        self.latency_generation = latency_generation
        self.latency_total = latency_total
        self.metadata = metadata

    def to_dict(self):
        """Return a summary dict: answer, document count, latencies rounded to 3 places, metadata."""
        summary = {
            "answer": self.answer,
            "num_docs_retrieved": len(self.retrieved_docs),
        }
        for latency_name in ("latency_retrieval", "latency_generation", "latency_total"):
            summary[latency_name] = round(getattr(self, latency_name), 3)
        # A falsy metadata attribute is reported as an empty dict.
        summary["metadata"] = self.metadata or {}
        return summary
class BaseRAG(ABC):
    """Abstract base class for retrieval-augmented generation pipelines.

    Subclasses implement :meth:`retrieve` and :meth:`generate`; :meth:`query`
    chains them (retrieve -> format context -> generate) and returns a
    ``RAGResult`` carrying the answer and per-stage latencies.
    """

    def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Store the collaborating components.

        Args:
            embedding_model: Object providing ``embed_documents`` / ``embed_query``.
            vector_store: Object providing ``add_documents`` / ``similarity_search``.
            llm_client: Client exposing ``chat.completions.create``.
            llm_model: Model name sent with every LLM request.
            **kwargs: Accepted and ignored, so subclasses can forward their
                leftover keyword arguments unfiltered.
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__

    def index_documents(self, documents):
        """Embed ``documents`` and add them, with their embeddings, to the vector store."""
        # Materialise once: the collection is traversed here and handed to the
        # store afterwards, which would silently lose data for a generator.
        documents = list(documents)
        if not documents:
            return
        texts = [doc.page_content for doc in documents]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(documents, embeddings)

    def _call_llm(self, prompt, system_prompt=None):
        """Send ``prompt`` (optionally preceded by a system message) and return the reply text.

        Returns:
            The content of the first choice, or ``""`` when the client reports
            no content, so callers can always apply string methods to it.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        response = self.llm_client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.1,
            max_tokens=2048,
        )
        content = response.choices[0].message.content
        return "" if content is None else content

    def _format_context(self, docs):
        """Render ``(document, score)`` pairs as numbered, separator-joined text blocks."""
        context_parts = []
        for position, (doc, score) in enumerate(docs, 1):
            source = doc.metadata.get("source", "unknown")
            context_parts.append(
                "[{}] (Score: {:.3f}, Source: {})\n{}".format(position, score, source, doc.page_content)
            )
        return "\n\n---\n\n".join(context_parts)

    @abstractmethod
    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(document, score)`` pairs relevant to ``query``."""

    @abstractmethod
    def generate(self, query, context):
        """Return the answer to ``query`` given the formatted ``context`` string."""

    def query(self, query, k=10):
        """Run retrieval and generation for ``query`` and time each stage.

        Returns:
            A ``RAGResult`` with the answer, the retrieved pairs, the three
            latencies in seconds, and metadata holding the method name and the
            context length. Note that ``num_context_tokens`` is the character
            count of the context string, not a tokenizer count.
        """
        # perf_counter is monotonic, so durations are immune to system clock
        # adjustments that can distort time.time() differences.
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total

        context = self._format_context(docs)

        start_generation = time.perf_counter()
        answer = self.generate(query, context)
        finished = time.perf_counter()

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=finished - start_generation,
            latency_total=finished - start_total,
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
class NaiveRAG(BaseRAG):
    """Baseline pipeline: one dense similarity search, then one templated LLM call."""

    def __init__(self, retrieval_prompt_template=None, **kwargs):
        """Set up the pipeline.

        Args:
            retrieval_prompt_template: ``str.format`` template with ``{context}``
                and ``{query}`` placeholders; a falsy value selects the built-in one.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        if not retrieval_prompt_template:
            retrieval_prompt_template = (
                "根据以下参考文档,回答问题。\n\n"
                "参考文档:\n{context}\n\n"
                "问题:{query}\n\n"
                "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
            )
        self.retrieval_prompt_template = retrieval_prompt_template

    def retrieve(self, query, k=10):
        """Return the vector store's top-``k`` matches for the embedded query."""
        embedded_query = self.embedding_model.embed_query(query)
        matches = self.vector_store.similarity_search(embedded_query, k)
        return matches

    def generate(self, query, context):
        """Fill the prompt template with ``context`` and ``query`` and ask the LLM."""
        filled_prompt = self.retrieval_prompt_template.format(context=context, query=query)
        return self._call_llm(filled_prompt)
class RerankRAG(BaseRAG):
    """Dense retrieval of a candidate pool followed by optional reranking."""

    def __init__(self, rerank_model=None, rerank_top_k=5, initial_k=20, **kwargs):
        """Set up the pipeline.

        Args:
            rerank_model: Object whose ``compute_score(pairs)`` scores
                ``(query, text)`` pairs; ``None`` disables reranking.
            rerank_top_k: Stored on the instance; ``retrieve`` does not read it.
                NOTE(review): possibly meant as the post-rerank cut-off - confirm.
            initial_k: Minimum size of the candidate pool fetched before reranking.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.rerank_model = rerank_model
        self.rerank_top_k = rerank_top_k
        self.initial_k = initial_k

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(document, score)`` pairs, reranked when a model is set.

        With a rerank model the returned scores are the reranker's scores as
        floats; without one they are the vector store's scores.
        """
        query_embedding = self.embedding_model.embed_query(query)
        # Never fetch fewer candidates than requested, otherwise k > initial_k
        # could not be satisfied.
        pool_size = max(self.initial_k, k)
        initial_docs = self.vector_store.similarity_search(query_embedding, pool_size)
        if not self.rerank_model or len(initial_docs) == 0:
            return initial_docs[:k]

        pairs = [(query, doc.page_content) for doc, _ in initial_docs]
        scores = self.rerank_model.compute_score(pairs)
        # A single pair may come back as a bare scalar (Python or numpy).
        if np.ndim(scores) == 0:
            scores = [scores]
        reranked = sorted(zip(initial_docs, scores), key=lambda item: item[1], reverse=True)
        return [(doc, float(score)) for (doc, _), score in reranked[:k]]

    def generate(self, query, context):
        """Answer ``query`` from the reranked ``context`` with one LLM call."""
        prompt = (
            "根据以下经过重排序的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class ParentDocumentRAG(BaseRAG):
    """Index small child chunks for search but return their larger parent chunks."""

    def __init__(self, parent_chunk_size=1500, child_chunk_size=512, **kwargs):
        """Set up the pipeline.

        Args:
            parent_chunk_size: Whitespace-separated words per parent chunk.
            child_chunk_size: Whitespace-separated words per child chunk.
            **kwargs: Forwarded to ``BaseRAG``.

        Raises:
            ValueError: If either chunk size is not positive.
        """
        super().__init__(**kwargs)
        if parent_chunk_size <= 0 or child_chunk_size <= 0:
            raise ValueError("chunk sizes must be positive")
        self.parent_chunk_size = parent_chunk_size
        self.child_chunk_size = child_chunk_size
        self.parent_docs = []
        self.child_to_parent = {}

    def index_documents(self, documents):
        """Split documents into parent and child chunks and index the children.

        Parents are kept in ``self.parent_docs``. Each child records the index
        of its parent both in ``self.child_to_parent`` (keyed by child position)
        and in its own ``parent_id`` metadata entry.
        """
        self.parent_docs = []
        self.child_to_parent = {}
        child_texts = []
        child_docs = []
        for doc in documents:
            # str.split() only breaks on whitespace, so text without spaces
            # becomes a single "word".
            words = doc.page_content.split()
            for start in range(0, len(words), self.parent_chunk_size):
                parent_words = words[start:start + self.parent_chunk_size]
                parent_id = len(self.parent_docs)
                self.parent_docs.append(Document(
                    page_content=" ".join(parent_words),
                    metadata=dict(doc.metadata, chunk_index=start, is_parent=True),
                ))
                for offset in range(0, len(parent_words), self.child_chunk_size):
                    child_text = " ".join(parent_words[offset:offset + self.child_chunk_size])
                    # Record the real parent; a fixed children-per-parent ratio
                    # is wrong for short trailing chunks.
                    self.child_to_parent[len(child_texts)] = parent_id
                    child_texts.append(child_text)
                    child_docs.append(Document(
                        page_content=child_text,
                        metadata={"is_parent": False, "parent_id": parent_id},
                    ))
        if child_texts:
            embeddings = self.embedding_model.embed_documents(child_texts)
            self.vector_store.add_documents(child_docs, embeddings)
        logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_texts)))

    def _find_parent_index(self, child_doc):
        """Return the index of ``child_doc``'s parent in ``self.parent_docs``, or ``None``."""
        metadata = getattr(child_doc, "metadata", None) or {}
        parent_id = metadata.get("parent_id")
        if isinstance(parent_id, int) and 0 <= parent_id < len(self.parent_docs):
            return parent_id
        # Fallback for children returned without the parent_id metadata:
        # match the child's first 50 characters against each parent.
        prefix = child_doc.page_content[:50]
        for index, parent_doc in enumerate(self.parent_docs):
            if prefix in parent_doc.page_content:
                return index
        return None

    def retrieve(self, query, k=10):
        """Search child chunks and return up to ``k`` distinct parents.

        Each parent carries the best score among its matching children.
        """
        query_embedding = self.embedding_model.embed_query(query)
        child_results = self.vector_store.similarity_search(query_embedding, k * 2)
        best_by_parent = {}
        for child_doc, score in child_results:
            parent_index = self._find_parent_index(child_doc)
            if parent_index is None:
                continue
            current = best_by_parent.get(parent_index)
            # Compare against the stored score, not the stored (doc, score) tuple.
            if current is None or score > current[1]:
                best_by_parent[parent_index] = (self.parent_docs[parent_index], score)
        parent_results = sorted(best_by_parent.values(), key=lambda item: item[1], reverse=True)
        return parent_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the parent-chunk ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class LLMFilterRAG(BaseRAG):
    """Dense retrieval whose candidates are filtered and re-weighted by an LLM relevance score."""

    # Score used when the LLM call fails or its reply contains no number.
    _DEFAULT_RELEVANCE = 0.5

    def __init__(self, filter_threshold=0.5, **kwargs):
        """Set up the pipeline.

        Args:
            filter_threshold: Minimum LLM relevance a candidate needs to be kept.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.filter_threshold = filter_threshold

    @staticmethod
    def _parse_score(response, default):
        """Extract a score from an LLM reply and clamp it to [0, 1].

        Tries the whole reply as a float first, then the first number found in
        it (e.g. "分数:0.8"); returns ``default`` when neither works.
        """
        text = (response or "").strip()
        try:
            value = float(text)
        except ValueError:
            match = re.search(r"-?\d*\.?\d+", text)
            if match is None:
                return default
            value = float(match.group())
        if value != value:  # NaN compares unequal to itself
            return default
        return min(1.0, max(0.0, value))

    def _score_relevance(self, query, doc):
        """Ask the LLM how relevant ``doc`` (first 300 characters) is to ``query``."""
        prompt = (
            "评估以下文档与问题的相关性,给出0-1之间的分数。只返回分数数字。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "相关性分数:".format(query, doc.page_content[:300])
        )
        try:
            response = self._call_llm(prompt)
        except Exception as exc:
            # Scoring is best-effort: a failed call falls back to the neutral
            # default, but the failure is logged rather than hidden.
            logger.warning("LLMFilterRAG: relevance scoring failed: {}".format(exc))
            return self._DEFAULT_RELEVANCE
        return self._parse_score(response, self._DEFAULT_RELEVANCE)

    def retrieve(self, query, k=10):
        """Return up to ``k`` candidates that pass the relevance threshold.

        Fetches ``k * 3`` candidates; each kept one is scored as similarity
        multiplied by LLM relevance.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 3)
        filtered_results = []
        for doc, score in initial_results:
            relevance = self._score_relevance(query, doc)
            if relevance >= self.filter_threshold:
                filtered_results.append((doc, score * relevance))
        filtered_results.sort(key=lambda item: item[1], reverse=True)
        return filtered_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the filtered ``context`` with one LLM call."""
        prompt = (
            "根据以下经过LLM精选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class QueryRoutingRAG(BaseRAG):
    """Dense retrieval boosted by keywords of an LLM-assigned query category."""

    # Category labels offered to the LLM, in prompt order.
    CATEGORIES = (
        "budget", "deadline", "qualification", "evaluation",
        "payment", "warranty", "delivery", "contact", "scope",
    )

    # Keywords whose presence in a document earns a boost for that category.
    CATEGORY_KEYWORDS = {
        "budget": ["预算", "金额", "价格", "费用", "报价"],
        "deadline": ["截止", "时间", "日期", "开标"],
        "qualification": ["资格", "要求", "证书", "业绩"],
        "evaluation": ["评标", "评价", "分数", "方法"],
        "payment": ["付款", "结算", "进度", "保证金"],
        "warranty": ["质保", "维修", "售后", "服务"],
        "delivery": ["交货", "工期", "交付", "地点"],
        "contact": ["联系人", "电话", "邮箱"],
        "scope": ["范围", "内容", "清单", "设备"],
    }

    # Score added per matching keyword.
    KEYWORD_BOOST = 0.1

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG``."""
        super().__init__(**kwargs)

    def _classify_query(self, query):
        """Ask the LLM for the query's category.

        Returns:
            The known category mentioned earliest in the reply, or the
            lower-cased reply itself when it names none of them.
        """
        categories_str = ", ".join(self.CATEGORIES)
        prompt = (
            "将以下问题分类到以下类别之一:{}\n\n"
            "问题:{}\n\n"
            "类别:".format(categories_str, query)
        )
        response = (self._call_llm(prompt) or "").strip().lower()
        # The reply may wrap the label in extra text, so look for the label
        # inside it instead of requiring an exact match.
        mentions = [(response.find(name), name) for name in self.CATEGORIES if name in response]
        if mentions:
            return min(mentions)[1]
        return response

    def retrieve(self, query, k=10):
        """Return up to ``k`` matches re-scored by category keyword hits.

        Each document's score is raised by ``KEYWORD_BOOST`` for every keyword
        of the predicted category found in its text; an unknown category
        yields no boost.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)
        category = self._classify_query(query)
        logger.info("QueryRouting: category={}".format(category))
        keywords = self.CATEGORY_KEYWORDS.get(category, [])
        enhanced_results = []
        for doc, score in semantic_results:
            content_lower = doc.page_content.lower()
            boost = sum(1 for keyword in keywords if keyword in content_lower) * self.KEYWORD_BOOST
            enhanced_results.append((doc, score + boost))
        enhanced_results.sort(key=lambda item: item[1], reverse=True)
        return enhanced_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the routed ``context`` with one LLM call."""
        prompt = (
            "根据以下经过查询路由的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class MetadataFilterRAG(BaseRAG):
    """Dense retrieval with an optional exact-match filter on document metadata."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG``."""
        super().__init__(**kwargs)

    def retrieve(self, query, k=10, metadata_filter=None):
        """Return up to ``k`` matches, optionally restricted by metadata.

        Args:
            query: Query text.
            k: Maximum number of pairs returned; ``k * 3`` candidates are fetched.
            metadata_filter: Optional mapping; a document is kept only when
                every key's metadata value equals the given value.
        """
        embedded_query = self.embedding_model.embed_query(query)
        candidates = self.vector_store.similarity_search(embedded_query, k * 3)
        if not metadata_filter:
            return candidates[:k]
        kept = [
            (doc, score)
            for doc, score in candidates
            if all(doc.metadata.get(key) == value for key, value in metadata_filter.items())
        ]
        return kept[:k]

    def generate(self, query, context):
        """Answer ``query`` from the metadata-filtered ``context`` with one LLM call."""
        template = (
            "根据以下经过元数据筛选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
class AdaptiveRAG(BaseRAG):
    """Let the LLM choose between semantic, keyword (BM25) and fused retrieval."""

    # (text looked for in the LLM reply, canonical strategy it selects)
    _STRATEGY_ALIASES = (
        ("keyword", "keyword"),
        ("exact", "keyword"),
        ("multi", "multi"),
        ("comprehensive", "multi"),
        ("semantic", "semantic"),
    )

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG``."""
        super().__init__(**kwargs)

    def _determine_strategy(self, query):
        """Ask the LLM for a strategy name and return its stripped, lower-cased reply."""
        prompt = (
            "根据问题类型,选择最合适的检索策略。\n"
            "可选策略:\n"
            "- semantic: 语义检索,适合概念性问题\n"
            "- keyword: 关键词检索,适合精确实体匹配\n"
            "- multi: 多路检索,适合复杂综合问题\n\n"
            "问题:{}\n\n"
            "策略名称(只返回name):".format(query)
        )
        return self._call_llm(prompt).strip().lower()

    def _normalize_strategy(self, raw_strategy):
        """Map a free-form LLM reply to ``"keyword"``, ``"multi"`` or ``"semantic"``.

        The alias mentioned earliest in the reply wins; a reply naming none of
        them selects ``"semantic"``.
        """
        text = (raw_strategy or "").lower()
        earliest = None
        for alias, canonical in self._STRATEGY_ALIASES:
            position = text.find(alias)
            if position != -1 and (earliest is None or position < earliest[0]):
                earliest = (position, canonical)
        return earliest[1] if earliest else "semantic"

    def _bm25_scores(self, query):
        """Score every document exposed by the vector store with BM25.

        Returns:
            ``(documents, scores)``, or ``None`` when ``rank_bm25`` is not
            installed, the store exposes no ``documents``, or scoring fails.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("AdaptiveRAG: rank_bm25 is not installed; skipping BM25")
            return None
        all_docs = getattr(self.vector_store, "documents", None)
        if not all_docs:
            return None
        try:
            bm25 = BM25Okapi([doc.page_content.split() for doc in all_docs])
            scores = bm25.get_scores(query.split())
        except (ZeroDivisionError, ValueError) as exc:
            logger.warning("AdaptiveRAG: BM25 scoring failed: {}".format(exc))
            return None
        return all_docs, scores

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(document, score)`` pairs using the chosen strategy.

        * keyword: BM25 hits with a positive score; falls back to semantic
          search when BM25 is unavailable or matches nothing.
        * multi: semantic results fused with max-normalised BM25 scores
          (0.6 semantic / 0.4 BM25 for documents found by both).
        * anything else: semantic search only.
        """
        strategy = self._normalize_strategy(self._determine_strategy(query))
        logger.info("AdaptiveRAG: strategy={}".format(strategy))

        lexical = self._bm25_scores(query) if strategy in ("keyword", "multi") else None

        if strategy == "keyword" and lexical is not None:
            all_docs, scores = lexical
            hits = [(doc, float(score)) for doc, score in zip(all_docs, scores) if score > 0]
            if hits:
                hits.sort(key=lambda item: item[1], reverse=True)
                return hits[:k]
            # No lexical match at all: semantic results beat an empty context.

        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)
        if strategy != "multi":
            return semantic_results[:k]

        # Documents are de-duplicated by their first 100 characters.
        fused = {}
        for doc, score in semantic_results:
            fused[doc.page_content[:100]] = (doc, score)
        if lexical is not None:
            all_docs, scores = lexical
            top_score = max(scores)
            if top_score <= 0:
                top_score = 1.0
            for doc, raw_score in zip(all_docs, scores):
                doc_id = doc.page_content[:100]
                norm_bm25 = raw_score / top_score
                if doc_id in fused:
                    fused[doc_id] = (doc, fused[doc_id][1] * 0.6 + norm_bm25 * 0.4)
                else:
                    fused[doc_id] = (doc, norm_bm25)
        sorted_results = sorted(fused.values(), key=lambda item: item[1], reverse=True)
        return sorted_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the adaptively retrieved ``context`` with one LLM call."""
        prompt = (
            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class HybridSearchRAG(BaseRAG):
    """Weighted combination of dense similarity and BM25 keyword scores."""

    def __init__(self, bm25_index=None, semantic_weight=0.5, **kwargs):
        """Set up the pipeline.

        Args:
            bm25_index: Optional pre-built BM25 index; replaced by
                ``index_documents``.
            semantic_weight: Weight of the dense score; BM25 gets ``1 - weight``.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.bm25_index = bm25_index
        self.semantic_weight = semantic_weight
        self._all_texts = []
        # page_content -> position of that document in the BM25 corpus.
        self._content_to_index = {}

    def index_documents(self, documents):
        """Index ``documents`` in the vector store and build the BM25 index.

        When ``rank_bm25`` is not installed only the vector store is filled
        and retrieval stays purely semantic.
        """
        documents = list(documents)
        super().index_documents(documents)
        if not documents:
            return
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("HybridSearchRAG: rank_bm25 is not installed; using semantic search only")
            return
        # Build unconditionally: gating this on an existing index meant the
        # default configuration never got a BM25 index at all.
        self._all_texts = [doc.page_content.split() for doc in documents]
        self._content_to_index = {}
        for position, doc in enumerate(documents):
            self._content_to_index.setdefault(doc.page_content, position)
        self.bm25_index = BM25Okapi(self._all_texts)

    def retrieve(self, query, k=10):
        """Return up to ``k`` pairs ranked by the combined score.

        Fetches ``k * 2`` dense candidates. Each candidate's BM25 score is
        looked up by its text in the indexed corpus and max-normalised;
        candidates not found in the corpus get a BM25 score of 0.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)
        if self.bm25_index is None or not self._all_texts:
            return semantic_results[:k]

        bm25_scores = self.bm25_index.get_scores(query.split())
        top_score = max(bm25_scores)
        if top_score <= 0:
            top_score = 1.0

        combined = []
        for doc, sem_score in semantic_results:
            # BM25 scores are ordered like the corpus, not like the semantic
            # ranking, so they must be looked up per document.
            corpus_index = self._content_to_index.get(doc.page_content)
            if corpus_index is None:
                bm25_score = 0.0
            else:
                bm25_score = bm25_scores[corpus_index] / top_score
            combined_score = (
                self.semantic_weight * sem_score
                + (1 - self.semantic_weight) * bm25_score
            )
            combined.append((doc, combined_score))
        combined.sort(key=lambda item: item[1], reverse=True)
        return combined[:k]

    def generate(self, query, context):
        """Answer ``query`` from the hybrid-search ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(结合语义搜索和BM25关键词搜索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class MultiQueryRAG(BaseRAG):
    """Expand the query into several LLM-written variants and merge their results."""

    def __init__(self, num_queries=3, **kwargs):
        """Set up the pipeline.

        Args:
            num_queries: Number of variants requested from the LLM.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.num_queries = num_queries

    def _generate_queries(self, query):
        """Return the original query followed by the LLM's variants.

        The reply is split on newlines and blank lines are dropped; the list
        is capped at ``num_queries + 1`` entries.
        """
        template = (
            "请将以下问题扩展为{}个不同角度的问题,"
            "用于检索更全面的信息。每行一个问题。\n\n"
            "原问题:{}\n\n"
            "扩展问题:"
        )
        reply = self._call_llm(template.format(self.num_queries, query))
        variants = [query]
        for line in reply.strip().split("\n"):
            cleaned = line.strip()
            if cleaned:
                variants.append(cleaned)
        return variants[:self.num_queries + 1]

    def retrieve(self, query, k=10):
        """Search once per variant and return the ``k`` best distinct documents.

        Documents are de-duplicated by their first 100 characters, keeping the
        highest score seen for each.
        """
        best_hits = {}
        for variant in self._generate_queries(query):
            variant_embedding = self.embedding_model.embed_query(variant)
            for doc, score in self.vector_store.similarity_search(variant_embedding, k):
                key = doc.page_content[:100]
                if key in best_hits and not score > best_hits[key][1]:
                    continue
                best_hits[key] = (doc, score)
        ranked = list(best_hits.values())
        ranked.sort(key=lambda item: item[1], reverse=True)
        return ranked[:k]

    def generate(self, query, context):
        """Answer ``query`` from the merged ``context`` with one LLM call."""
        template = (
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请综合多个查询角度的检索结果,详细回答。"
        )
        return self._call_llm(template.format(context, query))
class HyDERAG(BaseRAG):
    """Retrieve with embeddings of LLM-written hypothetical answers plus the query itself."""

    # Markers the LLM is asked to wrap each hypothetical passage in.
    _DOC_OPEN = "<doc>"
    _DOC_CLOSE = "</doc>"

    def __init__(self, num_hypotheses=3, **kwargs):
        """Set up the pipeline.

        Args:
            num_hypotheses: Number of hypothetical passages requested and used.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.num_hypotheses = num_hypotheses

    def _generate_hypothetical_docs(self, query):
        """Ask the LLM for hypothetical passages that could answer ``query``.

        Returns:
            Up to ``num_hypotheses`` non-empty passages. If the reply contains
            no delimited passage the whole reply is used as a single passage;
            an empty reply yields an empty list.
        """
        # The delimiters must be real, non-empty markers: with empty ones the
        # extraction pattern only ever matches empty strings.
        prompt = (
            "假设你是一个招投标专家,请根据以下问题,"
            "生成{}个可能包含答案的假想文档段落。"
            "每个段落用'{}'和'{}'分隔。\n\n"
            "问题:{}\n\n"
            "假想文档:".format(self.num_hypotheses, self._DOC_OPEN, self._DOC_CLOSE, query)
        )
        response = self._call_llm(prompt) or ""
        pattern = "{}(.*?){}".format(re.escape(self._DOC_OPEN), re.escape(self._DOC_CLOSE))
        docs = [passage.strip() for passage in re.findall(pattern, response, re.DOTALL)]
        docs = [passage for passage in docs if passage]
        if not docs and response.strip():
            docs = [response.strip()]
        return docs[:self.num_hypotheses]

    def retrieve(self, query, k=10):
        """Search with every hypothetical passage and the query; return the ``k`` best documents.

        Documents are de-duplicated by their first 100 characters, keeping the
        highest score seen for each.
        """
        all_results = {}

        def merge(results):
            for doc, score in results:
                doc_id = doc.page_content[:100]
                if doc_id not in all_results or score > all_results[doc_id][1]:
                    all_results[doc_id] = (doc, score)

        for hypo_doc in self._generate_hypothetical_docs(query):
            hypo_embedding = self.embedding_model.embed_query(hypo_doc)
            merge(self.vector_store.similarity_search(hypo_embedding, k))

        query_embedding = self.embedding_model.embed_query(query)
        merge(self.vector_store.similarity_search(query_embedding, k))

        sorted_docs = sorted(all_results.values(), key=lambda item: item[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Answer ``query`` from ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class StepBackRAG(BaseRAG):
    """Retrieve for both the query and an LLM-written, more abstract "step-back" question."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG``."""
        super().__init__(**kwargs)

    def _generate_step_back_query(self, query):
        """Ask the LLM for a higher-level version of ``query`` and return it stripped."""
        prompt = (
            "请将以下具体问题抽象为一个更通用的高层次问题,"
            "用于查找相关背景知识。只返回抽象后的问题。\n\n"
            "具体问题:{}\n\n"
            "高层次问题:".format(query)
        )
        return self._call_llm(prompt).strip()

    def retrieve(self, query, k=10):
        """Return up to ``k`` distinct documents from the direct and step-back searches.

        Each search fetches half of ``k`` rounded up (at least one), so small
        or odd ``k`` values are not starved. Documents are de-duplicated by
        their first 100 characters, keeping the highest score.
        """
        step_back_query = self._generate_step_back_query(query)
        logger.info("StepBack query: {}".format(step_back_query))

        # k // 2 would be 0 for k == 1 and would under-fetch for odd k.
        per_query_k = max(1, (k + 1) // 2)

        query_embedding = self.embedding_model.embed_query(query)
        combined = list(self.vector_store.similarity_search(query_embedding, per_query_k))
        if step_back_query:
            step_back_embedding = self.embedding_model.embed_query(step_back_query)
            combined.extend(self.vector_store.similarity_search(step_back_embedding, per_query_k))

        all_docs = {}
        for doc, score in combined:
            doc_id = doc.page_content[:100]
            if doc_id not in all_docs or score > all_docs[doc_id][1]:
                all_docs[doc_id] = (doc, score)
        sorted_docs = sorted(all_docs.values(), key=lambda item: item[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Answer ``query`` from the combined ``context`` with one LLM call."""
        prompt = (
            "结合以下参考文档(包含具体问题和抽象问题的检索结果),"
            "回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class ContextualCompressionRAG(BaseRAG):
    """Dense retrieval whose documents are compressed by an LLM to the query-relevant parts."""

    def __init__(self, compression_llm=None, **kwargs):
        """Set up the pipeline.

        Args:
            compression_llm: Client exposing ``chat.completions.create`` used
                for compression; defaults to the pipeline's ``llm_client``.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        # The fallback is the client stored by BaseRAG.__init__; a bare
        # ``llm_client`` name does not exist in this scope.
        self.compression_llm = compression_llm or self.llm_client

    def _compress_doc(self, query, doc):
        """Return the LLM's compressed version of ``doc`` with respect to ``query``."""
        prompt = (
            "请压缩以下文档片段,仅保留与问题相关的信息。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "压缩后的内容:".format(query, doc.page_content)
        )
        response = self.compression_llm.chat.completions.create(
            model=self.llm_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=512,
        )
        return response.choices[0].message.content

    def retrieve(self, query, k=10):
        """Return up to ``k`` compressed documents with their original scores.

        Fetches ``k * 2`` candidates and compresses them in rank order,
        skipping candidates whose compression is empty and stopping once ``k``
        results are collected, so no LLM call is spent on a candidate that
        would be cut off anyway.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
        compressed_results = []
        for doc, score in initial_results:
            if len(compressed_results) >= k:
                break
            compressed_text = self._compress_doc(query, doc)
            if not compressed_text or not compressed_text.strip():
                # Nothing relevant survived compression; let a later candidate take the slot.
                continue
            compressed_doc = Document(page_content=compressed_text, metadata=doc.metadata)
            compressed_results.append((compressed_doc, score))
        return compressed_results

    def generate(self, query, context):
        """Answer ``query`` from the compressed ``context`` with one LLM call."""
        prompt = (
            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class SelfRAG(BaseRAG):
    """Self-reflective retrieval: the LLM decides whether to retrieve and grades each document."""

    # Score used when an LLM grading reply contains no number.
    _DEFAULT_SCORE = 0.5

    def __init__(self, relevance_threshold=0.5, support_threshold=0.5, retrieval_threshold=0.6, critic_model=None, **kwargs):
        """Set up the pipeline.

        Args:
            relevance_threshold: Minimum relevance a document needs in ``retrieve``.
            support_threshold: Stored on the instance; not read by this class.
            retrieval_threshold: Stored on the instance; not read by this class.
            critic_model: Stored on the instance; not read by this class.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.relevance_threshold = relevance_threshold
        self.support_threshold = support_threshold
        self.retrieval_threshold = retrieval_threshold
        self.critic_model = critic_model

    def _ask_for_score(self, prompt):
        """Send a grading prompt and return the reply as a score clamped to [0, 1].

        Tries the whole reply as a float first, then the first number found in
        it; returns ``_DEFAULT_SCORE`` when the reply holds no number.
        """
        text = (self._call_llm(prompt) or "").strip()
        try:
            value = float(text)
        except ValueError:
            match = re.search(r"-?\d*\.?\d+", text)
            if match is None:
                return self._DEFAULT_SCORE
            value = float(match.group())
        if value != value:  # NaN compares unequal to itself
            return self._DEFAULT_SCORE
        return min(1.0, max(0.0, value))

    def _is_retrieval_needed(self, query):
        """Return True when the LLM's reply contains "是" or "yes"."""
        prompt = (
            "判断以下问题是否需要检索外部知识才能回答。"
            "只需回答'是'或'否'。\n\n"
            "问题:{}\n\n"
            "是否需要检索:".format(query)
        )
        response = self._call_llm(prompt).strip().lower()
        return "是" in response or "yes" in response

    def _evaluate_relevance(self, query, doc):
        """Return the LLM-graded relevance of ``doc`` to ``query``."""
        prompt = (
            "评估以下文档片段与问题的相关性,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "相关性分数:".format(query, doc.page_content)
        )
        return self._ask_for_score(prompt)

    def _evaluate_support(self, query, answer, doc):
        """Return the LLM-graded support of ``answer`` by ``doc``; ``query`` is not used in the prompt."""
        prompt = (
            "评估以下回答是否得到了文档的支持,给出0-1之间的分数。\n\n"
            "文档:{}\n\n"
            "回答:{}\n\n"
            "支持分数:".format(doc.page_content, answer)
        )
        return self._ask_for_score(prompt)

    def _evaluate_usefulness(self, query, answer):
        """Return the LLM-graded usefulness of ``answer`` for ``query``."""
        prompt = (
            "评估以下回答对于问题的有用性,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "回答:{}\n\n"
            "有用性分数:".format(query, answer)
        )
        return self._ask_for_score(prompt)

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents the LLM graded as relevant.

        Returns an empty list when the LLM judges retrieval unnecessary.
        Otherwise fetches ``k * 2`` candidates, keeps those at or above
        ``relevance_threshold`` and ranks them by relevance times similarity.
        """
        if not self._is_retrieval_needed(query):
            return []
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
        filtered_results = []
        for doc, score in initial_results:
            relevance = self._evaluate_relevance(query, doc)
            if relevance >= self.relevance_threshold:
                filtered_results.append((doc, relevance * score))
        filtered_results.sort(key=lambda item: item[1], reverse=True)
        return filtered_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the self-filtered ``context`` with one LLM call."""
        prompt = (
            "根据以下经过自反思筛选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请确保回答有充分的文档支持,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class CorrectiveRAG(BaseRAG):
    """Dense retrieval that adds a (simulated) web-search document when the LLM doubts the results."""

    # Score used when the LLM's grading reply contains no number.
    _DEFAULT_SCORE = 0.5
    # Score assigned to the injected web-search document.
    _WEB_DOC_SCORE = 0.8

    def __init__(self, correctness_threshold=0.6, **kwargs):
        """Set up the pipeline.

        Args:
            correctness_threshold: Correctness score below which the
                web-search document is added.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.correctness_threshold = correctness_threshold
        self.web_search_results = []

    def _evaluate_correctness(self, query, docs):
        """Return the LLM's 0-1 judgement of whether ``docs`` can answer ``query``.

        Only the first five documents are shown to the LLM. Returns 0.0 for an
        empty ``docs`` and ``_DEFAULT_SCORE`` when the reply holds no number.
        """
        if not docs:
            return 0.0
        context = "\n\n".join([doc.page_content for doc, _ in docs[:5]])
        prompt = (
            "评估以下文档集合是否能够正确回答问题,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "文档集合:\n{}\n\n"
            "正确性分数:".format(query, context)
        )
        text = (self._call_llm(prompt) or "").strip()
        try:
            value = float(text)
        except ValueError:
            # Accept replies that wrap the number in extra text.
            match = re.search(r"-?\d*\.?\d+", text)
            if match is None:
                return self._DEFAULT_SCORE
            value = float(match.group())
        if value != value:  # NaN compares unequal to itself
            return self._DEFAULT_SCORE
        return min(1.0, max(0.0, value))

    def _web_search(self, query):
        """Return a placeholder string standing in for real web-search results."""
        return "[Web search results for: {}] - Simulated external knowledge".format(query)

    def retrieve(self, query, k=10):
        """Return the dense matches, prefixed by a web-search document when needed.

        When the correctness score is below the threshold the result holds one
        extra entry (the web document, placed first), so it can contain up to
        ``k + 1`` pairs.
        """
        query_embedding = self.embedding_model.embed_query(query)
        # Copy so the insertion below never mutates a list owned by the store.
        results = list(self.vector_store.similarity_search(query_embedding, k))
        correctness_score = self._evaluate_correctness(query, results)
        if correctness_score < self.correctness_threshold:
            logger.info("CRAG: Correctness score {} below threshold, adding web search".format(correctness_score))
            web_doc = Document(
                page_content=self._web_search(query),
                metadata={"source": "web_search"},
            )
            results.insert(0, (web_doc, self._WEB_DOC_SCORE))
        return results

    def generate(self, query, context):
        """Answer ``query`` from ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(可能包含外部搜索结果),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。如果包含外部搜索结果,请注明。".format(context, query)
        )
        return self._call_llm(prompt)
class FLARERAG(BaseRAG):
    """Iterative retrieval: draft an answer, ask the LLM if more is needed, re-query, repeat."""

    def __init__(self, max_iterations=3, **kwargs):
        """Set up the pipeline.

        Args:
            max_iterations: Upper bound on retrieval rounds per query.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.max_iterations = max_iterations

    def _need_more_retrieval(self, query, current_answer):
        """Return True when the LLM's reply contains "是" or "yes"."""
        prompt = (
            "基于当前已有的信息,判断是否还需要更多检索才能完整回答问题。\n\n"
            "问题:{}\n\n"
            "当前已有信息的回答:{}\n\n"
            "是否需要更多检索(只回答是/否):".format(query, current_answer)
        )
        response = self._call_llm(prompt).strip().lower()
        return "是" in response or "yes" in response

    def _generate_next_query(self, query, current_answer):
        """Ask the LLM for a follow-up query that fills gaps in ``current_answer``."""
        prompt = (
            "基于当前回答的不足,生成一个新的查询来补充信息。\n\n"
            "原问题:{}\n\n"
            "当前回答:{}\n\n"
            "新查询:".format(query, current_answer)
        )
        return self._call_llm(prompt).strip()

    def retrieve(self, query, k=10):
        """Run up to ``max_iterations`` retrieval rounds and return the ``k`` best documents.

        After each round except the last, a draft answer is generated from
        everything collected so far; retrieval stops early when the LLM says
        no more is needed or produces an empty follow-up query. Documents are
        de-duplicated by their first 100 characters, keeping the highest score.
        """
        all_docs = {}
        current_query = query
        for iteration in range(self.max_iterations):
            query_embedding = self.embedding_model.embed_query(current_query)
            for doc, score in self.vector_store.similarity_search(query_embedding, k):
                doc_id = doc.page_content[:100]
                if doc_id not in all_docs or score > all_docs[doc_id][1]:
                    all_docs[doc_id] = (doc, score)

            # The reflection step only serves to pick the NEXT query, so on the
            # final round it would cost up to three LLM calls for nothing.
            if iteration == self.max_iterations - 1:
                break

            context = self._format_context(list(all_docs.values()))
            current_answer = self._generate_partial_answer(query, context)
            if not self._need_more_retrieval(query, current_answer):
                break
            next_query = self._generate_next_query(query, current_answer)
            if not next_query:
                break
            current_query = next_query
            logger.info("FLARE iteration {}, new query: {}".format(iteration + 1, current_query))

        sorted_docs = sorted(all_docs.values(), key=lambda item: item[1], reverse=True)
        return sorted_docs[:k]

    def _generate_partial_answer(self, query, context):
        """Answer ``query`` from ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档,给出问题的回答。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "回答:".format(context, query)
        )
        return self._call_llm(prompt)

    def generate(self, query, context):
        """Produce the final answer; identical to a partial answer over the final context."""
        return self._generate_partial_answer(query, context)
class RAPTORRAG(BaseRAG):
    """Index documents plus levels of LLM summaries built from clustered embeddings."""

    def __init__(self, max_clusters=50, summary_length=256, num_tree_levels=2, **kwargs):
        """Set up the pipeline.

        Args:
            max_clusters: Upper bound on clusters per level.
            summary_length: Length limit quoted in the summarisation prompt.
            num_tree_levels: Maximum number of summary levels to build.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.max_clusters = max_clusters
        self.summary_length = summary_length
        self.num_tree_levels = num_tree_levels
        self.hierarchical_docs = []

    def _summarize_texts(self, texts):
        """Summarise ``texts`` in batches of five; return one summary per batch."""
        summaries = []
        batch_size = 5
        for start in range(0, len(texts), batch_size):
            combined = "\n\n".join(texts[start:start + batch_size])
            prompt = (
                "请用一段话总结以下内容的核心要点({}字以内):\n\n"
                "{}\n\n"
                "总结:".format(self.summary_length, combined)
            )
            summaries.append(self._call_llm(prompt))
        return summaries

    def _cluster_embeddings(self, embeddings):
        """Group embedding indices with agglomerative clustering.

        Returns:
            A list of index lists: empty for no embeddings, a single cluster
            when fewer than two clusters are possible.
        """
        count = len(embeddings)
        if count == 0:
            return []
        # NOTE(review): with count <= max_clusters every item becomes its own
        # cluster, so nothing is merged at that level - confirm this is intended.
        n_clusters = min(self.max_clusters, count)
        if n_clusters < 2:
            return [list(range(count))]
        from sklearn.cluster import AgglomerativeClustering
        labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(np.array(embeddings))
        clusters = {}
        for index, label in enumerate(labels):
            clusters.setdefault(label, []).append(index)
        return list(clusters.values())

    def build_tree(self, documents):
        """Index ``documents`` and add up to ``num_tree_levels`` levels of summaries.

        Each level clusters the previous level's texts, summarises every
        cluster, stores the summaries in ``self.hierarchical_docs`` and adds
        them to the vector store. Building stops early when a level yields no
        summaries or is left with a single document.
        """
        documents = list(documents)
        self.index_documents(documents)
        current_level_docs = documents
        current_embeddings = None
        for level in range(self.num_tree_levels):
            if not current_level_docs:
                break
            # A lone summary cannot be clustered further; summarising it again
            # would only paraphrase it.
            if level > 0 and len(current_level_docs) < 2:
                break
            texts = [doc.page_content for doc in current_level_docs]
            # Summary levels reuse the embeddings computed when they were
            # indexed instead of embedding the same texts a second time.
            if current_embeddings is None:
                current_embeddings = self.embedding_model.embed_documents(texts)
            summary_docs = []
            for cluster in self._cluster_embeddings(current_embeddings):
                cluster_texts = [texts[index] for index in cluster]
                for summary in self._summarize_texts(cluster_texts):
                    if not summary:
                        continue  # skip empty LLM replies
                    summary_docs.append(Document(
                        page_content=summary,
                        metadata={"level": level + 1, "num_source_docs": len(cluster)},
                    ))
            self.hierarchical_docs.extend(summary_docs)
            logger.info("RAPTOR level {}: {} summaries created".format(level + 1, len(summary_docs)))
            if not summary_docs:
                break
            summary_embeddings = self.embedding_model.embed_documents(
                [doc.page_content for doc in summary_docs]
            )
            self.vector_store.add_documents(summary_docs, summary_embeddings)
            current_level_docs = summary_docs
            current_embeddings = summary_embeddings

    def retrieve(self, query, k=10):
        """Return the top-``k`` matches from the store, sorted by descending score."""
        query_embedding = self.embedding_model.embed_query(query)
        results = self.vector_store.similarity_search(query_embedding, k)
        return sorted(results, key=lambda item: item[1], reverse=True)[:k]

    def generate(self, query, context):
        """Answer ``query`` from the multi-level ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(包含多层次摘要信息),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class EnsembleRAG(BaseRAG):
    """Dense retrieval re-scored with BM25 computed over the retrieved candidates."""

    def __init__(self, methods=None, **kwargs):
        """Set up the pipeline.

        Args:
            methods: List of method names stored on the instance (defaults to
                ``["naive", "hybrid", "multi_query"]``); ``retrieve`` does not read it.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.methods = methods or ["naive", "hybrid", "multi_query"]

    def _bm25_over(self, base_results, query):
        """Return max-normalised BM25 scores for ``base_results``, or ``None`` if unavailable."""
        if not base_results:
            return None
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("EnsembleRAG: rank_bm25 is not installed; using semantic scores only")
            return None
        try:
            bm25 = BM25Okapi([doc.page_content.split() for doc, _ in base_results])
            scores = bm25.get_scores(query.split())
        except (ZeroDivisionError, ValueError) as exc:
            # Best-effort: fall back to semantic scores, but say why.
            logger.warning("EnsembleRAG: BM25 scoring failed: {}".format(exc))
            return None
        top_score = max(scores)
        if top_score <= 0:
            top_score = 1.0
        return [score / top_score for score in scores]

    def retrieve(self, query, k=10):
        """Return up to ``k`` pairs scored as the mean of dense and BM25 scores.

        Fetches ``k * 2`` dense candidates. Without BM25 the dense scores are
        used as they are. Documents are de-duplicated by their first 100
        characters.
        """
        query_embedding = self.embedding_model.embed_query(query)
        base_results = self.vector_store.similarity_search(query_embedding, k * 2)

        all_docs = {}
        normalized = self._bm25_over(base_results, query)
        if normalized is None:
            for doc, score in base_results:
                all_docs[doc.page_content[:100]] = (doc, score)
        else:
            for (doc, sem_score), bm25_score in zip(base_results, normalized):
                doc_id = doc.page_content[:100]
                # Fuse from the candidate's own dense score so entries sharing
                # an id do not blend an already fused value a second time.
                fused = sem_score * 0.5 + bm25_score * 0.5
                if doc_id not in all_docs or fused > all_docs[doc_id][1]:
                    all_docs[doc_id] = (doc, fused)

        sorted_docs = sorted(all_docs.values(), key=lambda item: item[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Answer ``query`` from the ensemble ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(通过集成多种检索方法获取),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请综合多种检索策略的结果,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class BidFieldExtractionRAG(BaseRAG):
    """Retrieve per-field evidence and ask the LLM to extract bid fields as JSON."""

    def __init__(self, fields=None, **kwargs):
        """Set up the pipeline.

        Args:
            fields: Field names to extract; a falsy value selects the built-in list.
            **kwargs: Forwarded to ``BaseRAG``.
        """
        super().__init__(**kwargs)
        self.fields = fields or [
            "project_name", "project_code", "budget_amount", "currency",
            "bid_deadline", "bid_open_time", "bid_location",
            "purchaser_name", "purchaser_contact", "purchaser_phone",
            "agency_name", "agency_contact", "agency_phone",
            "qualification_requirements", "bid_bond_amount",
            "performance_bond_amount", "warranty_period",
            "delivery_time", "delivery_location", "payment_terms",
            "evaluation_method", "scope_of_work",
        ]

    def retrieve(self, query, k=10):
        """Search once per field plus once for the bare query; return the ``k`` best documents.

        Each field search appends the field name to the query and fetches
        ``k // 3`` results, but at least one. Documents are de-duplicated by
        their first 100 characters, keeping the highest score.
        """
        all_docs = {}

        def merge(results):
            for doc, score in results:
                doc_id = doc.page_content[:100]
                if doc_id not in all_docs or score > all_docs[doc_id][1]:
                    all_docs[doc_id] = (doc, score)

        # k // 3 is 0 for k < 3, which would turn every field search into a no-op.
        per_field_k = max(1, k // 3)
        for field in self.fields:
            field_query = "{} {}".format(query, field)
            field_embedding = self.embedding_model.embed_query(field_query)
            merge(self.vector_store.similarity_search(field_embedding, per_field_k))

        query_embedding = self.embedding_model.embed_query(query)
        merge(self.vector_store.similarity_search(query_embedding, k))

        sorted_docs = sorted(all_docs.values(), key=lambda item: item[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Ask the LLM to extract ``self.fields`` from ``context`` as JSON.

        ``query`` is not part of the prompt; only the field list and the
        context are sent.
        """
        fields_str = "\n".join(["- {}".format(field) for field in self.fields])
        # Doubled braces are literal braces after str.format.
        prompt = (
            "你是一个招投标领域的专家。请根据提供的文档内容,提取以下字段信息:\n"
            "{}\n"
            "文档内容:\n"
            "{}\n"
            "请以JSON格式返回提取结果。如果某个字段无法从文档中提取,请返回null。\n"
            "JSON格式示例:\n"
            "{{\n"
            "\"project_name\": \"项目名称\",\n"
            "\"budget_amount\": 1000000,\n"
            "...\n"
            "}}".format(fields_str, context)
        )
        return self._call_llm(prompt)
class TableAwareRAG(BaseRAG):
    """Dense retrieval that surfaces pipe-delimited tables and numbered lists in documents."""

    _TABLE_MARKER = "[TABLE DETECTED]"
    _LIST_MARKER = "[STRUCTURED LIST DETECTED]"
    # Score bonus for documents in which a table was detected.
    _TABLE_BONUS = 0.1

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``BaseRAG``."""
        super().__init__(**kwargs)

    def _extract_table_info(self, doc):
        """Return ``doc``'s table rows or numbered items behind a marker line.

        Pipe-delimited rows take precedence; otherwise numbered items
        ("1." / "1、") are extracted; otherwise the content is returned
        unchanged. When a marker is emitted, only the matched rows/items are
        kept - the rest of the content is dropped.
        """
        content = doc.page_content
        # "(?:[\r\n]+|$)" also accepts a final row with no trailing newline,
        # which would otherwise be lost.
        table_pattern = r'(\|.*\|(?:[\r\n]+|$))'
        tables = re.findall(table_pattern, content)
        if tables:
            return "{}\n{}".format(self._TABLE_MARKER, ''.join(tables))
        numbered_pattern = r'(\d+[\.、].*?(?=\d+[\.、]|$))'
        numbered_items = re.findall(numbered_pattern, content, re.DOTALL)
        if numbered_items:
            return "{}\n{}".format(self._LIST_MARKER, ''.join(numbered_items))
        return content

    def retrieve(self, query, k=10):
        """Return up to ``k`` enhanced documents, with a score bonus for tables.

        Fetches ``k * 2`` candidates; each is replaced by a new document
        holding the extracted content and a ``has_table`` metadata flag.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
        enhanced_results = []
        for doc, score in initial_results:
            enhanced_content = self._extract_table_info(doc)
            # Test for the marker at the start, so raw text that merely
            # contains "[TABLE" is not flagged.
            has_table = enhanced_content.startswith(self._TABLE_MARKER)
            enhanced_doc = Document(
                page_content=enhanced_content,
                metadata=dict(doc.metadata, has_table=has_table),
            )
            bonus = self._TABLE_BONUS if has_table else 0.0
            enhanced_results.append((enhanced_doc, score + bonus))
        enhanced_results.sort(key=lambda item: item[1], reverse=True)
        return enhanced_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the table-aware ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(包含表格和结构化数据),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请特别注意表格和结构化数据中的信息,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class GraphRAG(BaseRAG):
    """Dense retrieval blended with documents linked to entities found in the query."""

    def __init__(self, **kwargs):
        """Forward keyword arguments to ``BaseRAG`` and start with an empty entity graph."""
        super().__init__(**kwargs)
        self.graph = {}

    def _extract_entities(self, text):
        """Return the distinct regex-matched entity strings found in ``text``."""
        matches = re.findall(r'[一-龥]{2,10}(?:公司|单位|招标|投标|项目|金额|时间)', text)
        distinct = set(matches)
        return list(distinct)

    def build_graph(self, documents):
        """Index ``documents`` and append each one under every entity it mentions."""
        self.index_documents(documents)
        for document in documents:
            for entity in self._extract_entities(document.page_content):
                self.graph.setdefault(entity, []).append(document)
        logger.info("Graph built with {} entities".format(len(self.graph)))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents from the semantic and entity-graph searches.

        A graph document starts at 0.7 and gains 0.1 each further time it is
        reached. Documents found by both routes score 0.6 * semantic +
        0.4 * graph. Documents are keyed by their first 100 characters.
        """
        embedded_query = self.embedding_model.embed_query(query)
        semantic_hits = self.vector_store.similarity_search(embedded_query, k)

        graph_hits = {}
        for entity in self._extract_entities(query):
            if entity not in self.graph:
                continue
            for linked_doc in self.graph[entity]:
                key = linked_doc.page_content[:100]
                if key in graph_hits:
                    graph_hits[key] = (linked_doc, graph_hits[key][1] + 0.1)
                else:
                    graph_hits[key] = (linked_doc, 0.7)

        merged = {}
        for doc, score in semantic_hits:
            merged[doc.page_content[:100]] = (doc, score)
        for key, (linked_doc, graph_score) in graph_hits.items():
            if key not in merged:
                merged[key] = (linked_doc, graph_score)
            else:
                merged[key] = (linked_doc, merged[key][1] * 0.6 + graph_score * 0.4)

        ranked = list(merged.values())
        ranked.sort(key=lambda item: item[1], reverse=True)
        return ranked[:k]

    def generate(self, query, context):
        """Answer ``query`` from the graph-augmented ``context`` with one LLM call."""
        template = (
            "根据以下参考文档(结合知识图谱和语义检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
class BM25RAG(BaseRAG):
    """Keyword-only pipeline: BM25 retrieval over whitespace tokens, no embeddings.

    ``query`` is inherited from ``BaseRAG``; the copy that used to live here
    was line-for-line identical to the base implementation.
    """

    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Set up the pipeline.

        Args:
            llm_client: Client exposing ``chat.completions.create``.
            llm_model: Model name sent with every LLM request.
            **kwargs: Forwarded to ``BaseRAG``; any ``embedding_model`` or
                ``vector_store`` entry is discarded because this pipeline
                always leaves both as ``None``.

        Raises:
            ImportError: If ``rank_bm25`` is not installed.
        """
        kwargs.pop("embedding_model", None)
        kwargs.pop("vector_store", None)
        super().__init__(
            embedding_model=None,
            vector_store=None,
            llm_client=llm_client,
            llm_model=llm_model,
            **kwargs
        )
        # Imported only to fail at construction time when the dependency is missing.
        from rank_bm25 import BM25Okapi  # noqa: F401
        self.bm25 = None
        self._all_texts = []
        self._all_documents = []

    def index_documents(self, documents):
        """Build the BM25 index over ``documents`` (tokens are whitespace-separated).

        An empty collection clears the index instead of handing BM25 an empty corpus.
        """
        # Keep a private list so later changes to the caller's list cannot
        # desynchronise documents and BM25 rows.
        self._all_documents = list(documents)
        self._all_texts = [doc.page_content.split() for doc in self._all_documents]
        if not self._all_documents:
            self.bm25 = None
            logger.info("BM25 index built with {} documents".format(0))
            return
        from rank_bm25 import BM25Okapi
        self.bm25 = BM25Okapi(self._all_texts)
        logger.info("BM25 index built with {} documents".format(len(self._all_documents)))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents with a positive BM25 score, best first.

        Returns an empty list when nothing has been indexed.
        """
        if self.bm25 is None:
            return []
        scores = self.bm25.get_scores(query.split())
        scored_docs = [
            (doc, float(score))
            for doc, score in zip(self._all_documents, scores)
            if score > 0
        ]
        scored_docs.sort(key=lambda item: item[1], reverse=True)
        return scored_docs[:k]

    def generate(self, query, context):
        """Answer ``query`` from the BM25 ``context`` with one LLM call."""
        prompt = (
            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
class TFIDFRAG(BaseRAG):
    """RAG pipeline whose retriever is TF-IDF cosine similarity (scikit-learn).

    No embedding model or vector store is used.

    NOTE(review): ``TfidfVectorizer`` is created with default settings; its
    default token pattern treats an unbroken run of word characters as one
    token, which is presumably a poor fit for unsegmented Chinese text --
    confirm against the corpus.
    """

    def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
        """Create a pipeline with no index yet.

        Args:
            llm_client: Chat-completions client used by ``_call_llm``.
            llm_model: Model name passed to the client.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # Go through the base initialiser so attributes it defines are always
        # present; embedding model and vector store are deliberately unset.
        super().__init__(
            embedding_model=None,
            vector_store=None,
            llm_client=llm_client,
            llm_model=llm_model,
        )
        self.vectorizer = None
        self.tfidf_matrix = None
        self._all_documents = []

    def index_documents(self, documents):
        """Fit a TF-IDF vectorizer on ``documents`` and store their matrix.

        Any previous index is replaced. With no documents the index is
        cleared, so ``retrieve`` returns ``[]``.

        Args:
            documents: Iterable of objects exposing ``page_content``.

        Raises:
            ValueError: Propagated from scikit-learn when the documents
                yield an empty vocabulary; the previous index is kept.
        """
        # Snapshot the input so later mutation of the caller's list cannot
        # shift document positions away from their matrix rows.
        docs = list(documents)
        if not docs:
            self._all_documents = []
            self.vectorizer = None
            self.tfidf_matrix = None
            logger.warning("TF-IDF index not built: no documents provided")
            return
        texts = [doc.page_content for doc in docs]
        from sklearn.feature_extraction.text import TfidfVectorizer
        vectorizer = TfidfVectorizer()
        # Fit before touching instance state: if fitting raises, documents
        # and matrix from the previous index remain aligned.
        tfidf_matrix = vectorizer.fit_transform(texts)
        self._all_documents = docs
        self.vectorizer = vectorizer
        self.tfidf_matrix = tfidf_matrix
        logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(docs), len(vectorizer.vocabulary_)))

    def retrieve(self, query, k=10):
        """Return the ``k`` documents most similar to ``query``.

        Args:
            query: The user question.
            k: Maximum number of results. Non-positive values give ``[]``.

        Returns:
            A list of ``(document, score)`` tuples with strictly positive
            cosine similarity, highest first. Empty when nothing has been
            indexed.
        """
        if self.tfidf_matrix is None or k <= 0:
            return []
        from sklearn.metrics.pairwise import cosine_similarity
        query_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        # Slicing already caps the result at the number of documents.
        top_indices = np.argsort(scores)[::-1][:k]
        return [
            (self._all_documents[idx], float(scores[idx]))
            for idx in top_indices
            if scores[idx] > 0
        ]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` from the TF-IDF-retrieved ``context``."""
        prompt = (
            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval and generation, recording the latency of each stage.

        Args:
            query: The user question.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` with the answer, the retrieved documents and the
            retrieval / generation / total latencies in seconds.
        """
        # perf_counter is monotonic; wall-clock time can jump when the
        # system clock is adjusted and would corrupt the measured durations.
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total
        context = self._format_context(docs)
        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1
        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # "num_context_tokens" is len() of the formatted context, not a
            # tokenizer count; the key name is kept for compatibility.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
class KeywordRAG(BaseRAG):
    """RAG pipeline with a selectable keyword retriever: BM25 or TF-IDF.

    No embedding model or vector store is used.
    """

    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", tokenizer=None, **kwargs):
        """Create a pipeline with no index yet.

        Args:
            search_method: ``"bm25"`` or ``"tfidf"``. Any other value is
                accepted, but no index is built and retrieval returns ``[]``.
            llm_client: Chat-completions client used by ``_call_llm``.
            llm_model: Model name passed to the client.
            tokenizer: Optional callable ``str -> list[str]`` used only by
                the BM25 method, for both documents and queries. Defaults to
                ``str.split`` (whitespace tokenisation); pass a word
                segmenter for text written without spaces.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        # Go through the base initialiser so attributes it defines are always
        # present; embedding model and vector store are deliberately unset.
        super().__init__(
            embedding_model=None,
            vector_store=None,
            llm_client=llm_client,
            llm_model=llm_model,
        )
        self.search_method = search_method
        self.tokenizer = tokenizer if tokenizer is not None else str.split
        self._all_documents = []
        self._all_texts = []
        self.bm25 = None
        self.tfidf_matrix = None
        self.vectorizer = None

    def index_documents(self, documents):
        """(Re)build the index for the configured search method.

        Any previous index is replaced. With no documents, or with an
        unrecognised ``search_method``, no index is built and a warning is
        logged; ``retrieve`` then returns ``[]``.

        Args:
            documents: Iterable of objects exposing ``page_content``.
        """
        # Snapshot the input so later mutation of the caller's list cannot
        # shift document positions away from their score positions.
        docs = list(documents)
        texts = [doc.page_content for doc in docs]
        # Build into locals and assign at the end, so a failure while
        # building leaves the previous index in a consistent state.
        bm25 = None
        vectorizer = None
        tfidf_matrix = None
        built = False
        if not docs:
            logger.warning("KeywordRAG ({}) index not built: no documents provided".format(self.search_method))
        elif self.search_method == "bm25":
            from rank_bm25 import BM25Okapi
            bm25 = BM25Okapi([self.tokenizer(text) for text in texts])
            built = True
        elif self.search_method == "tfidf":
            from sklearn.feature_extraction.text import TfidfVectorizer
            vectorizer = TfidfVectorizer()
            tfidf_matrix = vectorizer.fit_transform(texts)
            built = True
        else:
            logger.warning("KeywordRAG index not built: unknown search_method {!r}".format(self.search_method))
        self._all_documents = docs
        self._all_texts = texts
        self.bm25 = bm25
        self.vectorizer = vectorizer
        self.tfidf_matrix = tfidf_matrix
        if built:
            logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(docs)))

    def retrieve(self, query, k=10):
        """Return the ``k`` best-scoring documents for ``query``.

        Args:
            query: The user question.
            k: Maximum number of results. Non-positive values give ``[]``.

        Returns:
            A list of ``(document, score)`` tuples with strictly positive
            scores, highest first. Empty when the configured method has no
            index or is not recognised.
        """
        if k <= 0:
            return []
        if self.search_method == "bm25":
            # Nothing indexed yet: return no results instead of failing on None.
            if self.bm25 is None:
                return []
            scores = self.bm25.get_scores(self.tokenizer(query))
        elif self.search_method == "tfidf":
            if self.tfidf_matrix is None:
                return []
            from sklearn.metrics.pairwise import cosine_similarity
            query_vec = self.vectorizer.transform([query])
            scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        else:
            return []
        # Slicing already caps the result at the number of documents.
        top_indices = np.argsort(scores)[::-1][:k]
        return [
            (self._all_documents[idx], float(scores[idx]))
            for idx in top_indices
            if scores[idx] > 0
        ]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` from ``context``, naming the search method in the prompt."""
        prompt = (
            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(self.search_method, context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieval and generation, recording the latency of each stage.

        Args:
            query: The user question.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` with the answer, the retrieved documents and the
            retrieval / generation / total latencies in seconds.
        """
        # perf_counter is monotonic; wall-clock time can jump when the
        # system clock is adjusted and would corrupt the measured durations.
        start_total = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - start_total
        context = self._format_context(docs)
        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1
        total_time = time.perf_counter() - start_total
        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # "num_context_tokens" is len() of the formatted context, not a
            # tokenizer count; the key name is kept for compatibility.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )