|
|
@@ -0,0 +1,1377 @@
|
|
|
+from abc import ABC, abstractmethod
|
|
|
+from typing import List, Dict, Any, Optional, Tuple
|
|
|
+import time
|
|
|
+import re
|
|
|
+import numpy as np
|
|
|
+from loguru import logger
|
|
|
+
|
|
|
+from .document_processor import Document
|
|
|
+from .embedding_models import BaseEmbedding
|
|
|
+from .vector_stores import BaseVectorStore
|
|
|
+
|
|
|
+
|
|
|
class RAGResult(object):
    """Value object describing the outcome of one RAG query.

    Attributes:
        answer: Generated answer text.
        retrieved_docs: Retrieved items; a fresh empty list when omitted.
        latency_retrieval: Time spent in the retrieval stage.
        latency_generation: Time spent in the generation stage.
        latency_total: End-to-end time for the query.
        metadata: Free-form extra information; a fresh empty dict when omitted.
    """

    def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
                 latency_generation=0.0, latency_total=0.0, metadata=None):
        # None sentinels keep instances from sharing one mutable default.
        if retrieved_docs is None:
            retrieved_docs = []
        if metadata is None:
            metadata = {}

        self.answer = answer
        self.retrieved_docs = retrieved_docs
        self.latency_retrieval = latency_retrieval
        self.latency_generation = latency_generation
        self.latency_total = latency_total
        self.metadata = metadata

    def to_dict(self):
        """Return a summary dict: answer, document count, rounded latencies, metadata."""
        summary = {
            "answer": self.answer,
            "num_docs_retrieved": len(self.retrieved_docs),
        }
        # All three latencies are reported with three decimal places.
        for field in ("latency_retrieval", "latency_generation", "latency_total"):
            summary[field] = round(getattr(self, field), 3)
        summary["metadata"] = self.metadata if self.metadata else {}
        return summary
|
|
|
+
|
|
|
+
|
|
|
class BaseRAG(ABC):
    """Abstract base class for retrieval-augmented generation pipelines.

    Subclasses implement ``retrieve`` and ``generate``; ``query`` chains them
    and records per-stage latency.

    Collaborators are duck-typed. As used in this class:
        embedding_model: provides ``embed_documents(texts)``.
        vector_store: provides ``add_documents(documents, embeddings)``.
        llm_client: provides ``chat.completions.create(...)`` returning an
            object with ``choices[0].message.content``.
    """

    def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
        # Extra keyword arguments are accepted and ignored, so subclasses can
        # forward their whole **kwargs to this constructor.
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.llm_model = llm_model
        self.name = self.__class__.__name__

    def index_documents(self, documents):
        """Embed ``documents`` and add them to the vector store.

        Does nothing when ``documents`` is empty.
        """
        # Materialise once so a generator is not exhausted by the comprehension
        # before it reaches the vector store.
        documents = list(documents)
        if not documents:
            return
        texts = [doc.page_content for doc in documents]
        embeddings = self.embedding_model.embed_documents(texts)
        self.vector_store.add_documents(documents, embeddings)

    def _call_llm(self, prompt, system_prompt=None):
        """Send ``prompt`` (optionally preceded by a system message) to the LLM.

        Returns:
            The text of the first choice, or ``""`` when the client returns no
            content, so callers can safely call string methods on the result.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = self.llm_client.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.1,
            max_tokens=2048,
        )
        content = response.choices[0].message.content
        return content if content is not None else ""

    def _format_context(self, docs):
        """Render ``(doc, score)`` pairs as a numbered, separator-joined context string."""
        context_parts = []
        for i, (doc, score) in enumerate(docs, 1):
            source = doc.metadata.get("source", "unknown")
            context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
        return "\n\n---\n\n".join(context_parts)

    @abstractmethod
    def retrieve(self, query, k=10):
        """Return ``(doc, score)`` pairs for ``query``; ``k`` bounds the result size."""
        pass

    @abstractmethod
    def generate(self, query, context):
        """Return the answer text for ``query`` given the formatted ``context``."""
        pass

    def query(self, query, k=10):
        """Run retrieval then generation and return a ``RAGResult`` with timings.

        Latencies are in seconds and come from ``time.perf_counter`` so that
        system clock adjustments cannot distort them.
        """
        start_total = time.perf_counter()

        t0 = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - t0

        context = self._format_context(docs)

        t1 = time.perf_counter()
        answer = self.generate(query, context)
        generation_time = time.perf_counter() - t1

        total_time = time.perf_counter() - start_total

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=generation_time,
            latency_total=total_time,
            # NOTE: despite the key name this is the character length of the
            # context, not a tokenizer count; the key is kept for compatibility.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|
|
|
+
|
|
|
+
|
|
|
class NaiveRAG(BaseRAG):
    """Plain RAG: one embedding lookup followed by one templated LLM call.

    Args:
        retrieval_prompt_template: ``str.format`` template with ``{context}``
            and ``{query}`` placeholders; a built-in template is used when it
            is omitted or empty.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, retrieval_prompt_template=None, **kwargs):
        super().__init__(**kwargs)
        if not retrieval_prompt_template:
            retrieval_prompt_template = (
                "根据以下参考文档,回答问题。\n\n"
                "参考文档:\n{context}\n\n"
                "问题:{query}\n\n"
                "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
            )
        self.retrieval_prompt_template = retrieval_prompt_template

    def retrieve(self, query, k=10):
        """Embed the query and return the vector store's ``k`` nearest hits."""
        embedded_query = self.embedding_model.embed_query(query)
        hits = self.vector_store.similarity_search(embedded_query, k)
        return hits

    def generate(self, query, context):
        """Fill the prompt template and return the LLM's answer."""
        filled_prompt = self.retrieval_prompt_template.format(query=query, context=context)
        return self._call_llm(filled_prompt)
|
|
|
+
|
|
|
+
|
|
|
class RerankRAG(BaseRAG):
    """RAG that re-scores an initial candidate pool with a rerank model.

    Args:
        rerank_model: Object exposing ``compute_score(pairs)`` where ``pairs``
            is a list of ``(query, text)`` tuples. When falsy, the vector
            store's own ordering is used.
        rerank_top_k: Stored on the instance; ``retrieve`` does not read it
            (the ``k`` argument bounds the result size).
        initial_k: Minimum number of candidates fetched before reranking.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, rerank_model=None, rerank_top_k=5, initial_k=20, **kwargs):
        super().__init__(**kwargs)
        self.rerank_model = rerank_model
        self.rerank_top_k = rerank_top_k
        self.initial_k = initial_k

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(doc, score)`` pairs, reranked when a model is set."""
        query_embedding = self.embedding_model.embed_query(query)
        # Never fetch fewer candidates than the caller asked for; otherwise a
        # request with k > initial_k could not be satisfied.
        candidate_count = max(self.initial_k, k)
        initial_docs = self.vector_store.similarity_search(query_embedding, candidate_count)

        if not self.rerank_model or not initial_docs:
            return initial_docs[:k]

        pairs = [(query, doc.page_content) for doc, _ in initial_docs]
        # compute_score may hand back a bare scalar (including NumPy scalar
        # types) for a single pair; atleast_1d makes it iterable either way.
        scores = np.atleast_1d(self.rerank_model.compute_score(pairs))

        reranked = sorted(zip(initial_docs, scores), key=lambda item: item[1], reverse=True)
        # The rerank score replaces the original similarity score.
        return [(doc, float(score)) for (doc, _), score in reranked[:k]]

    def generate(self, query, context):
        """Answer ``query`` from the reranked ``context``."""
        prompt = (
            "根据以下经过重排序的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class ParentDocumentRAG(BaseRAG):
    """Parent/child retrieval: search small child chunks, answer from their parents.

    Documents are split on whitespace into parent chunks of
    ``parent_chunk_size`` words, each of which is split again into child
    chunks of ``child_chunk_size`` words. Only children are embedded.

    Args:
        parent_chunk_size: Words per parent chunk.
        child_chunk_size: Words per child chunk.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, parent_chunk_size=1500, child_chunk_size=512, **kwargs):
        super().__init__(**kwargs)
        self.parent_chunk_size = parent_chunk_size
        self.child_chunk_size = child_chunk_size
        self.parent_docs = []
        self.child_to_parent = {}

    def index_documents(self, documents):
        """Build parent chunks, then embed and store their child chunks.

        Each child's metadata records ``child_id`` and ``parent_index`` so the
        parent can be recovered directly from a search hit.
        """
        # NOTE(review): the vector store is not cleared here, so children from
        # an earlier call may remain in it - confirm whether re-indexing is expected.
        self.parent_docs = []
        self.child_to_parent = {}

        child_docs = []
        for doc in documents:
            words = doc.page_content.split()
            for start in range(0, len(words), self.parent_chunk_size):
                parent_words = words[start:start + self.parent_chunk_size]
                parent_index = len(self.parent_docs)
                self.parent_docs.append(Document(
                    page_content=" ".join(parent_words),
                    metadata=dict(doc.metadata, chunk_index=start, is_parent=True),
                ))

                for offset in range(0, len(parent_words), self.child_chunk_size):
                    child_id = len(child_docs)
                    # Map to the parent this child was actually cut from; the
                    # number of children per parent varies (short tail chunks).
                    self.child_to_parent[child_id] = parent_index
                    child_docs.append(Document(
                        page_content=" ".join(parent_words[offset:offset + self.child_chunk_size]),
                        metadata={
                            "is_parent": False,
                            "child_id": child_id,
                            "parent_index": parent_index,
                        },
                    ))

        if child_docs:
            embeddings = self.embedding_model.embed_documents([child.page_content for child in child_docs])
            self.vector_store.add_documents(child_docs, embeddings)

        logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_docs)))

    def _resolve_parent_index(self, child_doc):
        """Return the index in ``self.parent_docs`` for ``child_doc``, or ``None``."""
        metadata = getattr(child_doc, "metadata", None) or {}
        parent_index = metadata.get("parent_index")
        if isinstance(parent_index, int) and 0 <= parent_index < len(self.parent_docs):
            return parent_index

        # Fallback for hits whose metadata did not carry the parent index:
        # match on the first 50 characters of the child text.
        prefix = child_doc.page_content[:50]
        for index, parent_doc in enumerate(self.parent_docs):
            if prefix in parent_doc.page_content:
                return index
        return None

    def retrieve(self, query, k=10):
        """Search children, then return up to ``k`` distinct parents.

        Each parent carries the best score among its matching children.
        """
        query_embedding = self.embedding_model.embed_query(query)
        child_results = self.vector_store.similarity_search(query_embedding, k * 2)

        parent_map = {}
        for child_doc, score in child_results:
            parent_index = self._resolve_parent_index(child_doc)
            if parent_index is None:
                continue
            best = parent_map.get(parent_index)
            # Compare against the stored score, not the stored (doc, score) tuple.
            if best is None or score > best[1]:
                parent_map[parent_index] = (self.parent_docs[parent_index], score)

        parent_results = sorted(parent_map.values(), key=lambda x: x[1], reverse=True)
        return parent_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the parent-document ``context``."""
        prompt = (
            "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class LLMFilterRAG(BaseRAG):
    """RAG that asks the LLM to grade each candidate and drops low-graded ones.

    Args:
        filter_threshold: Minimum LLM relevance grade a document must reach.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    # First decimal or integer number in the reply, e.g. "0.8" in "0.8分".
    _SCORE_PATTERN = re.compile(r"\d*\.\d+|\d+")

    def __init__(self, filter_threshold=0.5, **kwargs):
        super().__init__(**kwargs)
        self.filter_threshold = filter_threshold

    def _score_relevance(self, query, doc):
        """Return the LLM's relevance grade for ``doc`` clamped to ``[0, 1]``.

        Scoring is best-effort: if the LLM call fails or the reply contains no
        number, the neutral grade ``0.5`` is returned.
        """
        prompt = (
            "评估以下文档与问题的相关性,给出0-1之间的分数。只返回分数数字。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "相关性分数:".format(query, doc.page_content[:300])
        )
        try:
            response = self._call_llm(prompt)
        except Exception as exc:
            # Deliberately tolerant: one failed grading call should not abort
            # the whole retrieval, but it should at least be visible in logs.
            logger.warning("LLMFilterRAG: relevance scoring failed ({}), using 0.5".format(exc))
            return 0.5

        match = self._SCORE_PATTERN.search(response or "")
        if match is None:
            return 0.5
        return min(1.0, max(0.0, float(match.group())))

    def retrieve(self, query, k=10):
        """Fetch ``3 * k`` candidates, keep those graded at or above the threshold.

        Surviving documents are scored ``similarity * relevance`` and the top
        ``k`` are returned.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 3)

        filtered_results = []
        for doc, score in initial_results:
            relevance = self._score_relevance(query, doc)
            if relevance >= self.filter_threshold:
                filtered_results.append((doc, score * relevance))

        filtered_results.sort(key=lambda x: x[1], reverse=True)
        return filtered_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the LLM-filtered ``context``."""
        prompt = (
            "根据以下经过LLM精选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class QueryRoutingRAG(BaseRAG):
    """RAG that classifies the query and boosts documents containing category keywords."""

    # Category name -> keywords whose presence in a document earns a boost.
    # Insertion order is also the order in which categories are listed in the
    # classification prompt.
    CATEGORY_KEYWORDS = {
        "budget": ["预算", "金额", "价格", "费用", "报价"],
        "deadline": ["截止", "时间", "日期", "开标"],
        "qualification": ["资格", "要求", "证书", "业绩"],
        "evaluation": ["评标", "评价", "分数", "方法"],
        "payment": ["付款", "结算", "进度", "保证金"],
        "warranty": ["质保", "维修", "售后", "服务"],
        "delivery": ["交货", "工期", "交付", "地点"],
        "contact": ["联系人", "电话", "邮箱"],
        "scope": ["范围", "内容", "清单", "设备"],
    }

    # Score added per matching keyword.
    KEYWORD_BOOST = 0.1

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _classify_query(self, query):
        """Ask the LLM for the query's category.

        Returns:
            The first known category name found in the lower-cased reply, or
            the stripped lower-cased reply itself when none is recognised.
        """
        categories_str = ", ".join(self.CATEGORY_KEYWORDS)
        prompt = (
            "将以下问题分类到以下类别之一:{}\n\n"
            "问题:{}\n\n"
            "类别:".format(categories_str, query)
        )
        response = self._call_llm(prompt).strip().lower()
        # The model may wrap the label ("类别:budget", "budget."), so look for
        # a known category inside the reply instead of requiring equality.
        for category in self.CATEGORY_KEYWORDS:
            if category in response:
                return category
        return response

    def retrieve(self, query, k=10):
        """Return the top ``k`` semantic hits after adding the category keyword boost."""
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)

        category = self._classify_query(query)
        logger.info("QueryRouting: category={}".format(category))

        # An unrecognised category yields no keywords and therefore no boost.
        keywords = self.CATEGORY_KEYWORDS.get(category, [])

        enhanced_results = []
        for doc, score in semantic_results:
            content_lower = doc.page_content.lower()
            matches = sum(1 for kw in keywords if kw in content_lower)
            enhanced_results.append((doc, score + matches * self.KEYWORD_BOOST))

        enhanced_results.sort(key=lambda x: x[1], reverse=True)
        return enhanced_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the routed ``context``."""
        prompt = (
            "根据以下经过查询路由的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class MetadataFilterRAG(BaseRAG):
    """RAG whose retrieval can be narrowed by exact-match metadata constraints."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def retrieve(self, query, k=10, metadata_filter=None):
        """Return up to ``k`` hits, optionally restricted by ``metadata_filter``.

        Args:
            query: Query text to embed.
            k: Maximum number of results; ``3 * k`` candidates are fetched.
            metadata_filter: Optional mapping; a document is kept only when
                ``doc.metadata.get(key) == value`` holds for every item.
        """
        embedded_query = self.embedding_model.embed_query(query)
        candidates = self.vector_store.similarity_search(embedded_query, k * 3)

        # No (or empty) filter: plain top-k of the candidate pool.
        if not metadata_filter:
            return candidates[:k]

        kept = [
            (doc, score)
            for doc, score in candidates
            if all(doc.metadata.get(key) == value for key, value in metadata_filter.items())
        ]
        return kept[:k]

    def generate(self, query, context):
        """Answer ``query`` from the metadata-filtered ``context``."""
        template = (
            "根据以下经过元数据筛选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
|
|
|
+
|
|
|
+
|
|
|
class AdaptiveRAG(BaseRAG):
    """RAG that lets the LLM pick between semantic, keyword (BM25) and blended retrieval.

    BM25 needs the optional ``rank_bm25`` package and a ``documents``
    attribute on the vector store; when either is missing, retrieval falls
    back to semantic search.
    """

    # Strategy labels recognised by ``retrieve``.
    _KEYWORD_STRATEGIES = ("keyword", "exact")
    _MULTI_STRATEGIES = ("multi", "comprehensive")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _determine_strategy(self, query):
        """Ask the LLM which retrieval strategy suits ``query``.

        Returns:
            A recognised strategy label when the lower-cased reply equals or
            contains one, otherwise the stripped lower-cased reply.
        """
        prompt = (
            "根据问题类型,选择最合适的检索策略。\n"
            "可选策略:\n"
            "- semantic: 语义检索,适合概念性问题\n"
            "- keyword: 关键词检索,适合精确实体匹配\n"
            "- multi: 多路检索,适合复杂综合问题\n\n"
            "问题:{}\n\n"
            "策略名称(只返回name):".format(query)
        )
        response = self._call_llm(prompt).strip().lower()
        known = self._KEYWORD_STRATEGIES + self._MULTI_STRATEGIES + ("semantic",)
        if response in known:
            return response
        # Tolerate decorated replies such as "策略:keyword".
        for label in known:
            if label in response:
                return label
        return response

    def _bm25_scores(self, query):
        """Score every stored document against ``query`` with BM25.

        Returns:
            ``(documents, scores)``; both are empty when ``rank_bm25`` is not
            installed or the vector store exposes no documents.
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("AdaptiveRAG: rank_bm25 is not installed, skipping BM25")
            return [], []

        all_docs = getattr(self.vector_store, "documents", None) or []
        if not all_docs:
            return [], []

        corpus = [doc.page_content.split() for doc in all_docs]
        return all_docs, BM25Okapi(corpus).get_scores(query.split())

    def retrieve(self, query, k=10):
        """Return up to ``k`` ``(doc, score)`` pairs using the chosen strategy."""
        strategy = self._determine_strategy(query)
        logger.info("AdaptiveRAG: strategy={}".format(strategy))

        if strategy in self._KEYWORD_STRATEGIES:
            all_docs, scores = self._bm25_scores(query)
            if all_docs:
                results = [
                    (doc, float(score))
                    for doc, score in zip(all_docs, scores)
                    if score > 0
                ]
                results.sort(key=lambda x: x[1], reverse=True)
                return results[:k]
            # BM25 unavailable: fall through to semantic search.

        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)

        if strategy in self._MULTI_STRATEGIES:
            # Documents are identified by the first 100 characters of their text.
            all_results = {doc.page_content[:100]: (doc, score) for doc, score in semantic_results}

            try:
                all_docs, bm25_scores = self._bm25_scores(query)
                if all_docs:
                    top_score = max(bm25_scores)
                    max_bm25 = top_score if top_score > 0 else 1.0
                    for doc, raw_score in zip(all_docs, bm25_scores):
                        doc_id = doc.page_content[:100]
                        norm_bm25 = float(raw_score / max_bm25)
                        if doc_id in all_results:
                            blended = all_results[doc_id][1] * 0.6 + norm_bm25 * 0.4
                            all_results[doc_id] = (doc, blended)
                        elif norm_bm25 > 0:
                            # Only add BM25-only documents that actually matched;
                            # zero-score documents would just pad the results.
                            all_results[doc_id] = (doc, norm_bm25)
            except Exception as exc:
                # The BM25 leg is best-effort; keep the semantic results but
                # record why the blend was skipped.
                logger.warning("AdaptiveRAG: BM25 blending failed ({}), using semantic results".format(exc))

            sorted_results = sorted(all_results.values(), key=lambda x: x[1], reverse=True)
            return sorted_results[:k]

        return semantic_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the adaptively retrieved ``context``."""
        prompt = (
            "根据以下经过自适应检索的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class HybridSearchRAG(BaseRAG):
    """RAG that blends vector similarity with a BM25 keyword score.

    Args:
        bm25_index: Optional pre-built index exposing ``get_scores(tokens)``.
            ``index_documents`` replaces it with an index over the documents
            it is given.
        semantic_weight: Weight of the similarity score in the blend; the
            BM25 score receives ``1 - semantic_weight``.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, bm25_index=None, semantic_weight=0.5, **kwargs):
        super().__init__(**kwargs)
        self.bm25_index = bm25_index
        self.semantic_weight = semantic_weight
        self._all_texts = []
        # page_content -> position of that document in the BM25 corpus.
        self._text_to_index = {}

    def index_documents(self, documents):
        """Index ``documents`` in the vector store and build the BM25 index.

        When the optional ``rank_bm25`` package is missing, only the vector
        store is populated and retrieval stays purely semantic.
        """
        documents = list(documents)
        super().index_documents(documents)
        if not documents:
            return

        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("HybridSearchRAG: rank_bm25 is not installed, using semantic search only")
            return

        # Build the index unconditionally: requiring a pre-existing index
        # meant the default configuration never performed hybrid search.
        self._all_texts = [doc.page_content.split() for doc in documents]
        self._text_to_index = {}
        for position, doc in enumerate(documents):
            # First occurrence wins for documents with identical text.
            self._text_to_index.setdefault(doc.page_content, position)
        self.bm25_index = BM25Okapi(self._all_texts)

    def retrieve(self, query, k=10):
        """Return up to ``k`` hits ranked by the weighted semantic + BM25 score.

        ``2 * k`` semantic candidates are re-scored. A candidate that is not
        part of the BM25 corpus gets a BM25 score of ``0.0``.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)

        if self.bm25_index is None or not self._text_to_index:
            return semantic_results[:k]

        bm25_scores = np.asarray(self.bm25_index.get_scores(query.split()), dtype=float)
        top_score = bm25_scores.max() if bm25_scores.size else 0.0
        # Scale into [0, 1]; leave the scores untouched when nothing matched.
        normalized_bm25 = bm25_scores / (top_score if top_score > 0 else 1.0)

        combined = []
        for doc, sem_score in semantic_results:
            # Look up the BM25 score of this document, not the one that
            # happens to share its rank position.
            position = self._text_to_index.get(doc.page_content)
            if position is not None and position < len(normalized_bm25):
                bm25_score = float(normalized_bm25[position])
            else:
                bm25_score = 0.0
            combined_score = (
                self.semantic_weight * sem_score +
                (1 - self.semantic_weight) * bm25_score
            )
            combined.append((doc, combined_score))

        combined.sort(key=lambda x: x[1], reverse=True)
        return combined[:k]

    def generate(self, query, context):
        """Answer ``query`` from the hybrid-search ``context``."""
        prompt = (
            "根据以下参考文档(结合语义搜索和BM25关键词搜索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class MultiQueryRAG(BaseRAG):
    """RAG that expands the question into several variants and merges their hits.

    Args:
        num_queries: Number of LLM-generated variants to keep, in addition to
            the original question.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, num_queries=3, **kwargs):
        super().__init__(**kwargs)
        self.num_queries = num_queries

    def _generate_queries(self, query):
        """Return the original query followed by up to ``num_queries`` variants.

        Variants are the non-blank lines of the LLM reply, stripped.
        """
        template = (
            "请将以下问题扩展为{}个不同角度的问题,"
            "用于检索更全面的信息。每行一个问题。\n\n"
            "原问题:{}\n\n"
            "扩展问题:"
        )
        reply = self._call_llm(template.format(self.num_queries, query))

        variants = [query]
        for line in reply.strip().split("\n"):
            cleaned = line.strip()
            if cleaned:
                variants.append(cleaned)

        limit = self.num_queries + 1
        return variants[:limit]

    def retrieve(self, query, k=10):
        """Search with every query variant and return the ``k`` best distinct hits.

        Hits are de-duplicated on the first 100 characters of their text,
        keeping the highest score seen for each.
        """
        best_by_prefix = {}

        for variant in self._generate_queries(query):
            variant_embedding = self.embedding_model.embed_query(variant)
            hits = self.vector_store.similarity_search(variant_embedding, k)
            for doc, score in hits:
                prefix = doc.page_content[:100]
                current = best_by_prefix.get(prefix)
                if current is None or score > current[1]:
                    best_by_prefix[prefix] = (doc, score)

        ranked = list(best_by_prefix.values())
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:k]

    def generate(self, query, context):
        """Answer ``query`` from the merged multi-query ``context``."""
        template = (
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请综合多个查询角度的检索结果,详细回答。"
        )
        return self._call_llm(template.format(context, query))
|
|
|
+
|
|
|
+
|
|
|
class HyDERAG(BaseRAG):
    """RAG that searches with LLM-written hypothetical passages as well as the query.

    Args:
        num_hypotheses: Maximum number of hypothetical passages to use.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, num_hypotheses=3, **kwargs):
        super().__init__(**kwargs)
        self.num_hypotheses = num_hypotheses

    def _generate_hypothetical_docs(self, query):
        """Return up to ``num_hypotheses`` passages extracted from the LLM reply.

        Passages are the contents of ``<doc>...</doc>`` tags; when the reply
        has no such tags, the whole reply is used as a single passage.
        """
        template = (
            "假设你是一个招投标专家,请根据以下问题,"
            "生成{}个可能包含答案的假想文档段落。"
            "每个段落用'<doc>'和'</doc>'分隔。\n\n"
            "问题:{}\n\n"
            "假想文档:"
        )
        reply = self._call_llm(template.format(self.num_hypotheses, query))

        passages = re.findall(r'<doc>(.*?)</doc>', reply, re.DOTALL)
        candidates = passages if passages else [reply]
        return candidates[:self.num_hypotheses]

    def retrieve(self, query, k=10):
        """Search with each hypothetical passage, then the query, and merge.

        Hits are de-duplicated on the first 100 characters of their text,
        keeping the highest score seen; the top ``k`` are returned.
        """
        # Hypothetical passages first, the literal query last.
        search_texts = list(self._generate_hypothetical_docs(query)) + [query]

        best_by_prefix = {}
        for text in search_texts:
            text_embedding = self.embedding_model.embed_query(text)
            hits = self.vector_store.similarity_search(text_embedding, k)
            for doc, score in hits:
                prefix = doc.page_content[:100]
                current = best_by_prefix.get(prefix)
                if current is None or score > current[1]:
                    best_by_prefix[prefix] = (doc, score)

        ranked = list(best_by_prefix.values())
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:k]

    def generate(self, query, context):
        """Answer ``query`` from the retrieved ``context``."""
        template = (
            "根据以下参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
|
|
|
+
|
|
|
+
|
|
|
class StepBackRAG(BaseRAG):
    """RAG that also searches with a more abstract ("step-back") version of the question."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _generate_step_back_query(self, query):
        """Ask the LLM to rewrite ``query`` as a more general question."""
        prompt = (
            "请将以下具体问题抽象为一个更通用的高层次问题,"
            "用于查找相关背景知识。只返回抽象后的问题。\n\n"
            "具体问题:{}\n\n"
            "高层次问题:".format(query)
        )
        return self._call_llm(prompt).strip()

    def retrieve(self, query, k=10):
        """Merge hits for the original and the step-back query; return the top ``k``.

        Each of the two searches requests half of ``k``, rounded up and never
        less than one, so small or odd ``k`` values still fetch candidates.
        Hits are de-duplicated on the first 100 characters of their text,
        keeping the highest score.
        """
        step_back_query = self._generate_step_back_query(query)
        logger.info("StepBack query: {}".format(step_back_query))

        # Plain k // 2 asks for zero results when k == 1.
        per_query_k = max(1, (k + 1) // 2)

        query_embedding = self.embedding_model.embed_query(query)
        direct_results = self.vector_store.similarity_search(query_embedding, per_query_k)

        step_back_embedding = self.embedding_model.embed_query(step_back_query)
        step_back_results = self.vector_store.similarity_search(step_back_embedding, per_query_k)

        all_docs = {}
        for doc, score in list(direct_results) + list(step_back_results):
            doc_id = doc.page_content[:100]
            if doc_id not in all_docs or score > all_docs[doc_id][1]:
                all_docs[doc_id] = (doc, score)

        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Answer ``query`` from the combined direct and step-back ``context``."""
        prompt = (
            "结合以下参考文档(包含具体问题和抽象问题的检索结果),"
            "回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class ContextualCompressionRAG(BaseRAG):
    """RAG that has an LLM condense each retrieved document before answering.

    Args:
        compression_llm: Client used for compression, exposing
            ``chat.completions.create(...)``. Defaults to the pipeline's own
            ``llm_client``.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, compression_llm=None, **kwargs):
        super().__init__(**kwargs)
        # Fall back to the client stored by BaseRAG; a bare ``llm_client``
        # name does not exist in this scope.
        self.compression_llm = compression_llm if compression_llm is not None else self.llm_client

    def _compress_doc(self, query, doc):
        """Return the LLM's query-focused condensation of ``doc.page_content``."""
        prompt = (
            "请压缩以下文档片段,仅保留与问题相关的信息。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "压缩后的内容:".format(query, doc.page_content)
        )
        messages = [{"role": "user", "content": prompt}]
        response = self.compression_llm.chat.completions.create(
            model=self.llm_model,
            messages=messages,
            temperature=0.0,
            max_tokens=512,
        )
        return response.choices[0].message.content

    def retrieve(self, query, k=10):
        """Return up to ``k`` compressed documents with their original scores.

        ``2 * k`` candidates are fetched and compressed in ranking order until
        ``k`` non-empty results are collected. Candidates whose compression
        comes back empty are skipped, and the spare candidates fill the gap.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)

        compressed_results = []
        for doc, score in initial_results:
            # Stop as soon as enough documents are collected; every further
            # compression would be an LLM call whose result is discarded.
            if len(compressed_results) >= k:
                break
            compressed_text = self._compress_doc(query, doc)
            if not compressed_text or not compressed_text.strip():
                continue
            # Copy the metadata so the compressed document does not share a
            # mutable dict with the stored original.
            compressed_doc = Document(page_content=compressed_text, metadata=dict(doc.metadata or {}))
            compressed_results.append((compressed_doc, score))

        return compressed_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the compressed ``context``."""
        prompt = (
            "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class SelfRAG(BaseRAG):
    """RAG with LLM self-checks: decide whether to retrieve, then grade each hit.

    Args:
        relevance_threshold: Minimum relevance grade a document must reach in
            ``retrieve``.
        support_threshold: Stored on the instance; not read by the methods of
            this class.
        retrieval_threshold: Stored on the instance; not read by the methods
            of this class.
        critic_model: Stored on the instance; not read by the methods of this
            class.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    # First decimal or integer number in a reply, e.g. "0.8" in "分数:0.8".
    _SCORE_PATTERN = re.compile(r"\d*\.\d+|\d+")

    def __init__(self, relevance_threshold=0.5, support_threshold=0.5, retrieval_threshold=0.6, critic_model=None, **kwargs):
        super().__init__(**kwargs)
        self.relevance_threshold = relevance_threshold
        self.support_threshold = support_threshold
        self.retrieval_threshold = retrieval_threshold
        self.critic_model = critic_model

    @classmethod
    def _parse_score(cls, response, default=0.5):
        """Extract a grade from an LLM reply, clamped to ``[0, 1]``.

        Returns ``default`` when the reply contains no number.
        """
        match = cls._SCORE_PATTERN.search(response or "")
        if match is None:
            return default
        return min(1.0, max(0.0, float(match.group())))

    def _is_retrieval_needed(self, query):
        """Return ``True`` when the LLM's reply contains "是" or "yes"."""
        prompt = (
            "判断以下问题是否需要检索外部知识才能回答。"
            "只需回答'是'或'否'。\n\n"
            "问题:{}\n\n"
            "是否需要检索:".format(query)
        )
        response = self._call_llm(prompt).strip().lower()
        return "是" in response or "yes" in response

    def _evaluate_relevance(self, query, doc):
        """Return the LLM's 0-1 grade of how relevant ``doc`` is to ``query``."""
        prompt = (
            "评估以下文档片段与问题的相关性,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "文档:{}\n\n"
            "相关性分数:".format(query, doc.page_content)
        )
        return self._parse_score(self._call_llm(prompt))

    def _evaluate_support(self, query, answer, doc):
        """Return the LLM's 0-1 grade of how well ``doc`` supports ``answer``.

        ``query`` is accepted for signature symmetry but not used in the prompt.
        """
        prompt = (
            "评估以下回答是否得到了文档的支持,给出0-1之间的分数。\n\n"
            "文档:{}\n\n"
            "回答:{}\n\n"
            "支持分数:".format(doc.page_content, answer)
        )
        return self._parse_score(self._call_llm(prompt))

    def _evaluate_usefulness(self, query, answer):
        """Return the LLM's 0-1 grade of how useful ``answer`` is for ``query``."""
        prompt = (
            "评估以下回答对于问题的有用性,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "回答:{}\n\n"
            "有用性分数:".format(query, answer)
        )
        return self._parse_score(self._call_llm(prompt))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents that pass the relevance self-check.

        Returns an empty list when the LLM judges retrieval unnecessary.
        Otherwise ``2 * k`` candidates are graded, those at or above
        ``relevance_threshold`` are kept and ranked by ``relevance * score``.
        """
        if not self._is_retrieval_needed(query):
            return []

        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)

        filtered_results = []
        for doc, score in initial_results:
            relevance = self._evaluate_relevance(query, doc)
            if relevance >= self.relevance_threshold:
                filtered_results.append((doc, relevance * score))

        filtered_results.sort(key=lambda x: x[1], reverse=True)
        return filtered_results[:k]

    def generate(self, query, context):
        """Answer ``query`` from the self-checked ``context``."""
        prompt = (
            "根据以下经过自反思筛选的参考文档,回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请确保回答有充分的文档支持,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class CorrectiveRAG(BaseRAG):
    """RAG that grades the retrieved set and supplements it when the grade is low.

    Args:
        correctness_threshold: Grade below which a web-search document is
            added to the results.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    # First decimal or integer number in a reply, e.g. "0.7" in "分数:0.7".
    _SCORE_PATTERN = re.compile(r"\d*\.\d+|\d+")

    def __init__(self, correctness_threshold=0.6, **kwargs):
        super().__init__(**kwargs)
        self.correctness_threshold = correctness_threshold
        # NOTE(review): initialised here but not written or read by any
        # method of this class - confirm whether callers rely on it.
        self.web_search_results = []

    def _evaluate_correctness(self, query, docs):
        """Return the LLM's 0-1 grade of whether ``docs`` can answer ``query``.

        Only the first five documents are shown to the LLM. Returns ``0.0``
        for an empty set and ``0.5`` when the reply contains no number.
        """
        if not docs:
            return 0.0

        context = "\n\n".join([doc.page_content for doc, _ in docs[:5]])
        prompt = (
            "评估以下文档集合是否能够正确回答问题,给出0-1之间的分数。\n\n"
            "问题:{}\n\n"
            "文档集合:\n{}\n\n"
            "正确性分数:".format(query, context)
        )
        response = self._call_llm(prompt)
        match = self._SCORE_PATTERN.search(response or "")
        if match is None:
            return 0.5
        return min(1.0, max(0.0, float(match.group())))

    def _web_search(self, query):
        """Return a placeholder string standing in for web search results."""
        return "[Web search results for: {}] - Simulated external knowledge".format(query)

    def retrieve(self, query, k=10):
        """Return the ``k`` nearest hits, preceded by a web document when graded low.

        When the web document is added the result holds one more item than
        the vector store returned.
        """
        query_embedding = self.embedding_model.embed_query(query)
        # Copy so the insert below never mutates a list owned by the vector store.
        results = list(self.vector_store.similarity_search(query_embedding, k))

        correctness_score = self._evaluate_correctness(query, results)

        if correctness_score < self.correctness_threshold:
            logger.info("CRAG: Correctness score {} below threshold, adding web search".format(correctness_score))
            web_doc = Document(
                page_content=self._web_search(query),
                metadata={"source": "web_search"}
            )
            # The web document is placed first with a fixed score of 0.8.
            results.insert(0, (web_doc, 0.8))

        return results

    def generate(self, query, context):
        """Answer ``query`` from ``context``, which may include a web-search entry."""
        prompt = (
            "根据以下参考文档(可能包含外部搜索结果),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。如果包含外部搜索结果,请注明。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class FLARERAG(BaseRAG):
    """Iterative RAG: retrieve, draft an answer, and retrieve again while the LLM asks for more.

    Args:
        max_iterations: Upper bound on the number of retrieval rounds.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    def __init__(self, max_iterations=3, **kwargs):
        super().__init__(**kwargs)
        self.max_iterations = max_iterations

    def _need_more_retrieval(self, query, current_answer):
        """Return ``True`` when the LLM's reply contains "是" or "yes"."""
        prompt = (
            "基于当前已有的信息,判断是否还需要更多检索才能完整回答问题。\n\n"
            "问题:{}\n\n"
            "当前已有信息的回答:{}\n\n"
            "是否需要更多检索(只回答是/否):".format(query, current_answer)
        )
        response = self._call_llm(prompt).strip().lower()
        return "是" in response or "yes" in response

    def _generate_next_query(self, query, current_answer):
        """Ask the LLM for a follow-up query that fills gaps in ``current_answer``."""
        prompt = (
            "基于当前回答的不足,生成一个新的查询来补充信息。\n\n"
            "原问题:{}\n\n"
            "当前回答:{}\n\n"
            "新查询:".format(query, current_answer)
        )
        return self._call_llm(prompt).strip()

    def retrieve(self, query, k=10):
        """Run up to ``max_iterations`` retrieval rounds and return the top ``k`` hits.

        After each round except the last, a draft answer is generated from
        everything retrieved so far. The loop ends early when the LLM reports
        that no further retrieval is needed. Hits are de-duplicated on the
        first 100 characters of their text, keeping the highest score.
        """
        all_docs = {}
        current_query = query

        for iteration in range(self.max_iterations):
            query_embedding = self.embedding_model.embed_query(current_query)
            results = self.vector_store.similarity_search(query_embedding, k)

            for doc, score in results:
                doc_id = doc.page_content[:100]
                if doc_id not in all_docs or score > all_docs[doc_id][1]:
                    all_docs[doc_id] = (doc, score)

            # The draft answer and follow-up query only steer the NEXT round,
            # so the final round skips those LLM calls entirely.
            if iteration == self.max_iterations - 1:
                break

            context = self._format_context(list(all_docs.values()))
            current_answer = self._generate_partial_answer(query, context)

            if not self._need_more_retrieval(query, current_answer):
                break

            current_query = self._generate_next_query(query, current_answer)
            logger.info("FLARE iteration {}, new query: {}".format(iteration + 1, current_query))

        sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
        return sorted_docs[:k]

    def _generate_partial_answer(self, query, context):
        """Return the LLM's answer to ``query`` based on ``context``."""
        prompt = (
            "根据以下参考文档,给出问题的回答。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "回答:".format(context, query)
        )
        return self._call_llm(prompt)

    def generate(self, query, context):
        """Answer ``query`` from ``context`` using the same prompt as the draft step."""
        return self._generate_partial_answer(query, context)
|
|
|
+
|
|
|
+
|
|
|
class RAPTORRAG(BaseRAG):
    """RAG over a summary tree: documents are clustered and summarised level by level.

    Summaries are added to the same vector store as the source documents, so
    a single similarity search can return both.

    Args:
        max_clusters: Upper bound on the number of clusters per level.
        summary_length: Length limit quoted to the LLM in the summary prompt.
        num_tree_levels: Maximum number of summary levels to build.
        **kwargs: Forwarded to ``BaseRAG``.
    """

    # Number of texts combined into one summarisation prompt.
    _SUMMARY_BATCH_SIZE = 5

    def __init__(self, max_clusters=50, summary_length=256, num_tree_levels=2, **kwargs):
        super().__init__(**kwargs)
        self.max_clusters = max_clusters
        self.summary_length = summary_length
        self.num_tree_levels = num_tree_levels
        self.hierarchical_docs = []

    def _summarize_texts(self, texts):
        """Return one LLM summary per batch of up to five ``texts``."""
        summaries = []
        for start in range(0, len(texts), self._SUMMARY_BATCH_SIZE):
            combined = "\n\n".join(texts[start:start + self._SUMMARY_BATCH_SIZE])
            prompt = (
                "请用一段话总结以下内容的核心要点({}字以内):\n\n"
                "{}\n\n"
                "总结:".format(self.summary_length, combined)
            )
            summaries.append(self._call_llm(prompt))
        return summaries

    def _cluster_embeddings(self, embeddings):
        """Group embedding indices into clusters with agglomerative clustering.

        Returns:
            A list of index lists. When fewer than two clusters would be
            formed, a single cluster holding every index is returned.
        """
        from sklearn.cluster import AgglomerativeClustering

        total = len(embeddings)
        # Ask for at most half as many clusters as points. With as many
        # clusters as points every cluster is a singleton and the "summary"
        # level merely paraphrases each document on its own.
        n_clusters = min(self.max_clusters, total // 2)
        if n_clusters < 2:
            return [list(range(total))]

        labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(np.array(embeddings))

        clusters = {}
        for index, label in enumerate(labels):
            clusters.setdefault(label, []).append(index)
        return list(clusters.values())

    def build_tree(self, documents):
        """Index ``documents`` and build up to ``num_tree_levels`` summary levels.

        Each level clusters the previous level's documents, summarises every
        cluster and stores the summaries in the vector store and in
        ``self.hierarchical_docs``. Building stops early once a level holds
        fewer than two documents. Does nothing for empty input.
        """
        documents = list(documents)
        if not documents:
            return

        self.index_documents(documents)

        current_level_docs = documents
        for level in range(self.num_tree_levels):
            # Nothing left to merge: a single document cannot be clustered.
            if len(current_level_docs) < 2:
                break

            texts = [doc.page_content for doc in current_level_docs]
            embeddings = self.embedding_model.embed_documents(texts)
            clusters = self._cluster_embeddings(embeddings)

            summary_docs = []
            for cluster in clusters:
                cluster_texts = [texts[i] for i in cluster]
                for summary in self._summarize_texts(cluster_texts):
                    summary_docs.append(Document(
                        page_content=summary,
                        metadata={"level": level + 1, "num_source_docs": len(cluster)}
                    ))

            self.hierarchical_docs.extend(summary_docs)
            if summary_docs:
                summary_embeddings = self.embedding_model.embed_documents([d.page_content for d in summary_docs])
                self.vector_store.add_documents(summary_docs, summary_embeddings)

            current_level_docs = summary_docs
            logger.info("RAPTOR level {}: {} summaries created".format(level + 1, len(summary_docs)))

    def retrieve(self, query, k=10):
        """Return the ``k`` best hits from the vector store, sorted by score."""
        query_embedding = self.embedding_model.embed_query(query)
        hits = list(self.vector_store.similarity_search(query_embedding, k))
        return sorted(hits, key=lambda x: x[1], reverse=True)[:k]

    def generate(self, query, context):
        """Answer ``query`` from ``context``, which may mix documents and summaries."""
        prompt = (
            "根据以下参考文档(包含多层次摘要信息),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class EnsembleRAG(BaseRAG):
    """RAG variant that re-scores dense retrieval hits with a BM25 keyword score."""

    def __init__(self, methods=None, **kwargs):
        """Store the configured method names.

        Args:
            methods: Names of the ensemble members; defaults to
                ``["naive", "hybrid", "multi_query"]``.
            **kwargs: Passed through to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        # Stored for callers; retrieve() and generate() below do not read it.
        self.methods = methods or ["naive", "hybrid", "multi_query"]

    def _bm25_rescore(self, query, base_results):
        """Blend semantic scores 50/50 with max-normalised BM25 scores.

        The BM25 index is built only over ``base_results``. Each hit is blended
        from its own semantic score, so hits that share a key are never blended
        twice; the best blended score per key is kept.

        Args:
            query: Query text; tokenised on whitespace.
            base_results: List of ``(doc, score)`` pairs from the vector store.

        Returns:
            A dict mapping document key to ``(doc, blended_score)``, or ``None``
            when BM25 is unavailable or fails (best-effort re-scoring).
        """
        try:
            from rank_bm25 import BM25Okapi
        except ImportError:
            logger.warning("rank_bm25 is not installed; EnsembleRAG uses semantic scores only")
            return None

        try:
            # NOTE(review): str.split() only separates on whitespace, which is
            # a weak tokeniser for unsegmented Chinese text.
            corpus = [doc.page_content.split() for doc, _ in base_results]
            bm25_scores = BM25Okapi(corpus).get_scores(query.split())
            top_score = float(bm25_scores.max())
            max_bm25 = top_score if top_score > 0 else 1.0

            blended = {}
            for (doc, semantic_score), raw_score in zip(base_results, bm25_scores):
                doc_id = doc.page_content[:100]
                bm25_score = float(raw_score) / max_bm25
                score = semantic_score * 0.5 + bm25_score * 0.5
                if doc_id not in blended or score > blended[doc_id][1]:
                    blended[doc_id] = (doc, score)
            return blended
        except Exception as exc:
            # Re-scoring is an optional refinement: report and fall back rather
            # than fail the whole retrieval.
            logger.warning("BM25 re-scoring failed ({}); using semantic scores only".format(exc))
            return None

    def retrieve(self, query, k=10):
        """Fetch ``2 * k`` dense hits, re-score them with BM25 and return the top ``k``.

        Hits are de-duplicated on the first 100 characters of ``page_content``.

        Args:
            query: Query text.
            k: Maximum number of hits to return.

        Returns:
            Up to ``k`` ``(doc, score)`` pairs, highest score first.
        """
        query_embedding = self.embedding_model.embed_query(query)
        base_results = list(self.vector_store.similarity_search(query_embedding, k * 2))

        all_docs = {}
        for doc, score in base_results:
            all_docs[doc.page_content[:100]] = (doc, score)

        if base_results:
            rescored = self._bm25_rescore(query, base_results)
            if rescored is not None:
                all_docs = rescored

        sorted_docs = sorted(all_docs.values(), key=lambda pair: pair[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        prompt = (
            "根据以下参考文档(通过集成多种检索方法获取),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请综合多种检索策略的结果,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class BidFieldExtractionRAG(BaseRAG):
    """RAG variant that extracts a fixed set of bid/tender fields as JSON."""

    def __init__(self, fields=None, **kwargs):
        """Store the list of field names to extract.

        Args:
            fields: Field names; defaults to the built-in bid field list below.
            **kwargs: Passed through to ``BaseRAG.__init__``.
        """
        super().__init__(**kwargs)
        self.fields = fields or [
            "project_name", "project_code", "budget_amount", "currency",
            "bid_deadline", "bid_open_time", "bid_location",
            "purchaser_name", "purchaser_contact", "purchaser_phone",
            "agency_name", "agency_contact", "agency_phone",
            "qualification_requirements", "bid_bond_amount",
            "performance_bond_amount", "warranty_period",
            "delivery_time", "delivery_location", "payment_terms",
            "evaluation_method", "scope_of_work"
        ]

    @staticmethod
    def _keep_best(all_docs, results):
        """Merge ``results`` into ``all_docs``, keeping the highest score per key.

        Documents are keyed by the first 100 characters of ``page_content``.
        """
        for doc, score in results:
            doc_id = doc.page_content[:100]
            if doc_id not in all_docs or score > all_docs[doc_id][1]:
                all_docs[doc_id] = (doc, score)

    def retrieve(self, query, k=10):
        """Search once per field plus once for the bare query and merge the hits.

        Args:
            query: Query text; each field name is appended to it for the
                per-field searches.
            k: Maximum number of hits to return.

        Returns:
            Up to ``k`` ``(doc, score)`` pairs, highest score first.
        """
        all_docs = {}

        # k // 3 is 0 for k < 3, which would make every per-field search ask
        # for zero hits; always request at least one.
        per_field_k = max(1, k // 3)
        for field in self.fields:
            field_query = "{} {}".format(query, field)
            field_embedding = self.embedding_model.embed_query(field_query)
            self._keep_best(
                all_docs,
                self.vector_store.similarity_search(field_embedding, per_field_k),
            )

        query_embedding = self.embedding_model.embed_query(query)
        self._keep_best(all_docs, self.vector_store.similarity_search(query_embedding, k))

        sorted_docs = sorted(all_docs.values(), key=lambda pair: pair[1], reverse=True)
        return sorted_docs[:k]

    def generate(self, query, context):
        """Ask the LLM to extract ``self.fields`` from ``context`` as JSON.

        Args:
            query: Not used; the prompt is built from the field list and
                ``context`` only.
            context: Document text to extract from.

        Returns:
            The raw result of ``_call_llm``.
        """
        fields_str = "\n".join(["- {}".format(f) for f in self.fields])
        # NOTE(review): the indentation inside the JSON example was
        # reconstructed as 4 spaces -- confirm against the original prompt.
        prompt = """你是一个招投标领域的专家。请根据提供的文档内容,提取以下字段信息:

{}

文档内容:
{}

请以JSON格式返回提取结果。如果某个字段无法从文档中提取,请返回null。

JSON格式示例:
{{
    "project_name": "项目名称",
    "budget_amount": 1000000,
    ...
}}""".format(fields_str, context)
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class TableAwareRAG(BaseRAG):
    """RAG variant that tags pipe-delimited tables / numbered lists and boosts table hits."""

    # One table row: a line that starts and ends with "|". The row may be
    # followed by line breaks or sit at the very end of the content, so the
    # last row is kept even without a trailing newline.
    _TABLE_ROW_RE = re.compile(r'\|.*\|(?:[\r\n]+|$)')

    # A numbered item: digits followed by "." or "、" and running up to the next
    # marker or the end. The (?!\d) guard stops decimals and dates such as
    # "100.5" or "2024.01.15" from being taken for list markers.
    _LIST_ITEM_RE = re.compile(r'\d+[\.、](?!\d).*?(?=\d+[\.、](?!\d)|$)', re.DOTALL)

    def __init__(self, **kwargs):
        """Forward all arguments to ``BaseRAG.__init__``."""
        super().__init__(**kwargs)

    def _extract_table_info(self, doc):
        """Return a tagged view of the structured part of ``doc``.

        Args:
            doc: Object exposing ``page_content``.

        Returns:
            ``"[TABLE DETECTED]\\n"`` plus the concatenated table rows when any
            are found; otherwise ``"[STRUCTURED LIST DETECTED]\\n"`` plus the
            concatenated numbered items when any are found; otherwise the
            content unchanged. In the first two cases text outside the matched
            spans is not included.
        """
        content = doc.page_content

        tables = self._TABLE_ROW_RE.findall(content)
        if tables:
            return "[TABLE DETECTED]\n{}".format(''.join(tables))

        numbered_items = self._LIST_ITEM_RE.findall(content)
        if numbered_items:
            return "[STRUCTURED LIST DETECTED]\n{}".format(''.join(numbered_items))

        return content

    def retrieve(self, query, k=10):
        """Fetch ``2 * k`` hits, tag structured content and favour table hits.

        Each hit is rebuilt as a new ``Document`` whose content comes from
        ``_extract_table_info`` and whose metadata gains a ``has_table`` flag.
        Hits tagged as tables get 0.1 added to their score.

        Args:
            query: Query text.
            k: Maximum number of hits to return.

        Returns:
            Up to ``k`` ``(doc, score)`` pairs, highest score first.
        """
        query_embedding = self.embedding_model.embed_query(query)
        initial_results = self.vector_store.similarity_search(query_embedding, k * 2)

        enhanced_results = []
        for doc, score in initial_results:
            enhanced_content = self._extract_table_info(doc)
            has_table = "[TABLE" in enhanced_content
            enhanced_doc = Document(
                page_content=enhanced_content,
                metadata=dict(doc.metadata, has_table=has_table)
            )
            table_bonus = 0.1 if has_table else 0.0
            enhanced_results.append((enhanced_doc, score + table_bonus))

        enhanced_results.sort(key=lambda pair: pair[1], reverse=True)
        return enhanced_results[:k]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        prompt = (
            "根据以下参考文档(包含表格和结构化数据),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请特别注意表格和结构化数据中的信息,详细回答。".format(context, query)
        )
        return self._call_llm(prompt)
|
|
|
+
|
|
|
+
|
|
|
class GraphRAG(BaseRAG):
    """RAG variant that combines vector search with an entity-to-documents lookup."""

    def __init__(self, **kwargs):
        """Forward all arguments to ``BaseRAG.__init__`` and start with an empty graph."""
        super().__init__(**kwargs)
        # Maps an entity string to the list of documents that mention it.
        self.graph = {}

    def _extract_entities(self, text):
        """Return the distinct regex-matched entity strings found in ``text``.

        An entity is 2-10 CJK characters followed by one of the fixed suffixes
        in the pattern. The order of the returned list is not defined.
        """
        matches = re.findall(r'[一-龥]{2,10}(?:公司|单位|招标|投标|项目|金额|时间)', text)
        unique_matches = set(matches)
        return list(unique_matches)

    def build_graph(self, documents):
        """Index ``documents`` and record which entities each one mentions."""
        self.index_documents(documents)

        for document in documents:
            for entity in self._extract_entities(document.page_content):
                self.graph.setdefault(entity, []).append(document)

        logger.info("Graph built with {} entities".format(len(self.graph)))

    def retrieve(self, query, k=10):
        """Merge vector-search hits with documents linked to the query's entities.

        Graph hits start at 0.7 and gain 0.1 for every further query entity
        that links to the same document. A document found by both routes gets
        ``0.6 * semantic + 0.4 * graph``; one found only through the graph
        keeps its graph score. Documents are keyed by the first 100 characters
        of ``page_content``.

        Args:
            query: Query text.
            k: Maximum number of hits to return.

        Returns:
            Up to ``k`` ``(doc, score)`` pairs, highest score first.
        """
        query_embedding = self.embedding_model.embed_query(query)
        semantic_results = self.vector_store.similarity_search(query_embedding, k)

        graph_hits = {}
        for entity in self._extract_entities(query):
            for linked_doc in self.graph.get(entity, ()):
                key = linked_doc.page_content[:100]
                if key in graph_hits:
                    graph_hits[key] = (linked_doc, graph_hits[key][1] + 0.1)
                else:
                    graph_hits[key] = (linked_doc, 0.7)

        merged = {
            hit_doc.page_content[:100]: (hit_doc, hit_score)
            for hit_doc, hit_score in semantic_results
        }

        for key, (linked_doc, graph_score) in graph_hits.items():
            if key in merged:
                graph_score = merged[key][1] * 0.6 + graph_score * 0.4
            merged[key] = (linked_doc, graph_score)

        ranked = list(merged.values())
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:k]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        template = (
            "根据以下参考文档(结合知识图谱和语义检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。"
        )
        return self._call_llm(template.format(context, query))
|
|
|
+
|
|
|
+
|
|
|
class BM25RAG(BaseRAG):
    """Keyword-only RAG: BM25 retrieval over tokenised documents, no embeddings."""

    def __init__(self, llm_client=None, llm_model="gpt-4o", tokenizer=None, **kwargs):
        """Create an un-indexed BM25 retriever.

        Args:
            llm_client: Client used by ``_call_llm``.
            llm_model: Model name used by ``_call_llm``.
            tokenizer: Optional callable ``str -> list[str]`` applied to
                documents and queries. Defaults to whitespace splitting, which
                is a weak choice for text without spaces between words.
            **kwargs: Accepted and ignored, so this class can be built with the
                same arguments as the embedding-based variants.
        """
        # No embedding model or vector store: retrieval is purely lexical.
        super().__init__(embedding_model=None, vector_store=None,
                         llm_client=llm_client, llm_model=llm_model)
        # Imported only so a missing rank_bm25 surfaces at construction time.
        from rank_bm25 import BM25Okapi  # noqa: F401
        self.tokenizer = tokenizer
        self.bm25 = None
        self._all_texts = []
        self._all_documents = []

    def _tokenize(self, text):
        """Split ``text`` with the configured tokenizer (whitespace by default)."""
        if self.tokenizer is not None:
            return self.tokenizer(text)
        return text.split()

    def index_documents(self, documents):
        """Build the BM25 index over ``documents``, replacing any previous one.

        When no document yields a token the index is cleared instead of built,
        and ``retrieve`` then returns no hits.

        Args:
            documents: Objects exposing ``page_content``.
        """
        from rank_bm25 import BM25Okapi

        self._all_documents = documents
        self._all_texts = [self._tokenize(doc.page_content) for doc in documents]

        if not any(self._all_texts):
            # BM25Okapi needs at least one token to compute its statistics.
            self.bm25 = None
            logger.warning("BM25 index not built: no tokens in {} documents".format(len(documents)))
            return

        self.bm25 = BM25Okapi(self._all_texts)
        logger.info("BM25 index built with {} documents".format(len(documents)))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents with a positive BM25 score.

        Args:
            query: Query text, tokenised the same way as the documents.
            k: Maximum number of hits.

        Returns:
            ``(doc, score)`` pairs, highest score first; empty when nothing
            has been indexed.
        """
        if self.bm25 is None:
            return []

        scores = self.bm25.get_scores(self._tokenize(query))

        scored_docs = [
            (self._all_documents[i], float(score))
            for i, score in enumerate(scores)
            if score > 0
        ]
        scored_docs.sort(key=lambda pair: pair[1], reverse=True)
        return scored_docs[:k]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        prompt = (
            "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieve -> format context -> generate and time each stage.

        Args:
            query: Question text.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` with the answer, the retrieved hits, the three
            latencies in seconds and a metadata dict.
        """
        # perf_counter is monotonic, so intervals cannot go negative if the
        # wall clock is adjusted mid-request.
        started = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - started

        context = self._format_context(docs)

        generation_started = time.perf_counter()
        answer = self.generate(query, context)
        finished = time.perf_counter()

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=finished - generation_started,
            latency_total=finished - started,
            # NOTE(review): despite the key name this is len(context), not a
            # token count -- confirm what consumers expect.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|
|
|
+
|
|
|
+
|
|
|
class TFIDFRAG(BaseRAG):
    """Keyword-only RAG: TF-IDF cosine-similarity retrieval, no embeddings."""

    def __init__(self, llm_client=None, llm_model="gpt-4o", tokenizer=None, **kwargs):
        """Create an un-indexed TF-IDF retriever.

        Args:
            llm_client: Client used by ``_call_llm``.
            llm_model: Model name used by ``_call_llm``.
            tokenizer: Optional callable ``str -> list[str]`` handed to
                ``TfidfVectorizer``. When omitted the vectorizer's default
                tokenisation is used.
            **kwargs: Accepted and ignored, so this class can be built with the
                same arguments as the embedding-based variants.
        """
        # No embedding model or vector store: retrieval is purely lexical.
        super().__init__(embedding_model=None, vector_store=None,
                         llm_client=llm_client, llm_model=llm_model)
        self.tokenizer = tokenizer
        self.vectorizer = None
        self.tfidf_matrix = None
        self._all_documents = []

    def index_documents(self, documents):
        """Fit the TF-IDF vectorizer on ``documents``, replacing any previous index.

        With no documents the index is cleared instead of fitted, and
        ``retrieve`` then returns no hits.

        Args:
            documents: Objects exposing ``page_content``.
        """
        from sklearn.feature_extraction.text import TfidfVectorizer

        self._all_documents = documents
        texts = [doc.page_content for doc in documents]

        if not texts:
            # The vectorizer cannot be fitted on an empty corpus.
            self.vectorizer = None
            self.tfidf_matrix = None
            logger.warning("TF-IDF index not built: no documents supplied")
            return

        if self.tokenizer is None:
            self.vectorizer = TfidfVectorizer()
        else:
            # token_pattern is unused once a tokenizer is given; passing None
            # avoids scikit-learn's "parameter will not be used" warning.
            self.vectorizer = TfidfVectorizer(tokenizer=self.tokenizer, token_pattern=None)
        self.tfidf_matrix = self.vectorizer.fit_transform(texts)
        logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(documents), len(self.vectorizer.vocabulary_)))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents with positive cosine similarity to ``query``.

        Args:
            query: Query text, transformed with the fitted vectorizer.
            k: Maximum number of hits.

        Returns:
            ``(doc, score)`` pairs, highest score first; empty when nothing
            has been indexed.
        """
        if self.tfidf_matrix is None:
            return []

        from sklearn.metrics.pairwise import cosine_similarity

        query_vec = self.vectorizer.transform([query])
        scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()

        k = min(k, len(scores))
        top_indices = np.argsort(scores)[::-1][:k]

        return [
            (self._all_documents[idx], float(scores[idx]))
            for idx in top_indices
            if scores[idx] > 0
        ]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        prompt = (
            "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieve -> format context -> generate and time each stage.

        Args:
            query: Question text.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` with the answer, the retrieved hits, the three
            latencies in seconds and a metadata dict.
        """
        # perf_counter is monotonic, so intervals cannot go negative if the
        # wall clock is adjusted mid-request.
        started = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - started

        context = self._format_context(docs)

        generation_started = time.perf_counter()
        answer = self.generate(query, context)
        finished = time.perf_counter()

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=finished - generation_started,
            latency_total=finished - started,
            # NOTE(review): despite the key name this is len(context), not a
            # token count -- confirm what consumers expect.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|
|
|
+
|
|
|
+
|
|
|
class KeywordRAG(BaseRAG):
    """Keyword-only RAG with a selectable backend: ``"bm25"`` or ``"tfidf"``."""

    def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o",
                 tokenizer=None, **kwargs):
        """Create an un-indexed keyword retriever.

        Args:
            search_method: ``"bm25"`` or ``"tfidf"``. Any other value builds no
                index and makes ``retrieve`` return no hits.
            llm_client: Client used by ``_call_llm``.
            llm_model: Model name used by ``_call_llm``.
            tokenizer: Optional callable ``str -> list[str]``. Used for BM25
                documents and queries (default: whitespace splitting) and
                handed to ``TfidfVectorizer`` for TF-IDF (default: the
                vectorizer's own tokenisation).
            **kwargs: Accepted and ignored, so this class can be built with the
                same arguments as the embedding-based variants.
        """
        # No embedding model or vector store: retrieval is purely lexical.
        super().__init__(embedding_model=None, vector_store=None,
                         llm_client=llm_client, llm_model=llm_model)
        self.search_method = search_method
        self.tokenizer = tokenizer
        self._all_documents = []
        self._all_texts = []
        self.bm25 = None
        self.tfidf_matrix = None
        self.vectorizer = None

    def _tokenize(self, text):
        """Split ``text`` with the configured tokenizer (whitespace by default)."""
        if self.tokenizer is not None:
            return self.tokenizer(text)
        return text.split()

    def index_documents(self, documents):
        """Build the index for ``self.search_method`` over ``documents``.

        Any index from an earlier call is discarded first so that scores can
        never be paired with a different document list.

        Args:
            documents: Objects exposing ``page_content``.
        """
        self._all_documents = documents
        self._all_texts = [doc.page_content for doc in documents]
        self.bm25 = None
        self.vectorizer = None
        self.tfidf_matrix = None

        if self.search_method not in ("bm25", "tfidf"):
            logger.warning("KeywordRAG: unsupported search_method {!r}; no index built".format(self.search_method))
            return

        if not self._all_texts:
            # Neither backend can be fitted on an empty corpus.
            logger.warning("KeywordRAG ({}) index not built: no documents supplied".format(self.search_method))
            return

        if self.search_method == "bm25":
            from rank_bm25 import BM25Okapi
            tokenized = [self._tokenize(t) for t in self._all_texts]
            if not any(tokenized):
                # BM25Okapi needs at least one token to compute its statistics.
                logger.warning("KeywordRAG (bm25) index not built: documents contain no tokens")
                return
            self.bm25 = BM25Okapi(tokenized)
        else:
            from sklearn.feature_extraction.text import TfidfVectorizer
            if self.tokenizer is None:
                self.vectorizer = TfidfVectorizer()
            else:
                # token_pattern is unused once a tokenizer is given; passing
                # None avoids scikit-learn's "will not be used" warning.
                self.vectorizer = TfidfVectorizer(tokenizer=self.tokenizer, token_pattern=None)
            self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)

        logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(documents)))

    def retrieve(self, query, k=10):
        """Return up to ``k`` documents with a positive score for ``query``.

        Args:
            query: Query text.
            k: Maximum number of hits.

        Returns:
            ``(doc, score)`` pairs, highest score first; empty when the
            selected backend has no index or ``search_method`` is unsupported.
        """
        if self.search_method == "bm25":
            if self.bm25 is None:
                return []
            scores = self.bm25.get_scores(self._tokenize(query))
        elif self.search_method == "tfidf":
            if self.vectorizer is None or self.tfidf_matrix is None:
                return []
            from sklearn.metrics.pairwise import cosine_similarity
            query_vec = self.vectorizer.transform([query])
            scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        else:
            return []

        k = min(k, len(scores))
        top_indices = np.argsort(scores)[::-1][:k]

        return [
            (self._all_documents[idx], float(scores[idx]))
            for idx in top_indices
            if scores[idx] > 0
        ]

    def generate(self, query, context):
        """Ask the LLM to answer ``query`` using ``context`` as reference text."""
        prompt = (
            "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
            "参考文档:\n{}\n\n"
            "问题:{}\n\n"
            "请详细回答。".format(self.search_method, context, query)
        )
        return self._call_llm(prompt)

    def query(self, query, k=10):
        """Run retrieve -> format context -> generate and time each stage.

        Args:
            query: Question text.
            k: Maximum number of documents to retrieve.

        Returns:
            A ``RAGResult`` with the answer, the retrieved hits, the three
            latencies in seconds and a metadata dict.
        """
        # perf_counter is monotonic, so intervals cannot go negative if the
        # wall clock is adjusted mid-request.
        started = time.perf_counter()
        docs = self.retrieve(query, k)
        retrieval_time = time.perf_counter() - started

        context = self._format_context(docs)

        generation_started = time.perf_counter()
        answer = self.generate(query, context)
        finished = time.perf_counter()

        return RAGResult(
            answer=answer,
            retrieved_docs=docs,
            latency_retrieval=retrieval_time,
            latency_generation=finished - generation_started,
            latency_total=finished - started,
            # NOTE(review): despite the key name this is len(context), not a
            # token count -- confirm what consumers expect.
            metadata={"method": self.name, "num_context_tokens": len(context)},
        )
|