rag_methods.py 51 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377
  1. from abc import ABC, abstractmethod
  2. from typing import List, Dict, Any, Optional, Tuple
  3. import time
  4. import re
  5. import numpy as np
  6. from loguru import logger
  7. from .document_processor import Document
  8. from .embedding_models import BaseEmbedding
  9. from .vector_stores import BaseVectorStore
  10. class RAGResult(object):
  11. def __init__(self, answer="", retrieved_docs=None, latency_retrieval=0.0,
  12. latency_generation=0.0, latency_total=0.0, metadata=None):
  13. self.answer = answer
  14. self.retrieved_docs = retrieved_docs if retrieved_docs is not None else []
  15. self.latency_retrieval = latency_retrieval
  16. self.latency_generation = latency_generation
  17. self.latency_total = latency_total
  18. self.metadata = metadata if metadata is not None else {}
  19. def to_dict(self):
  20. return {
  21. "answer": self.answer,
  22. "num_docs_retrieved": len(self.retrieved_docs),
  23. "latency_retrieval": round(self.latency_retrieval, 3),
  24. "latency_generation": round(self.latency_generation, 3),
  25. "latency_total": round(self.latency_total, 3),
  26. "metadata": self.metadata or {},
  27. }
  28. class BaseRAG(ABC):
  29. def __init__(self, embedding_model=None, vector_store=None, llm_client=None, llm_model="gpt-4o", **kwargs):
  30. self.embedding_model = embedding_model
  31. self.vector_store = vector_store
  32. self.llm_client = llm_client
  33. self.llm_model = llm_model
  34. self.name = self.__class__.__name__
  35. def index_documents(self, documents):
  36. texts = [doc.page_content for doc in documents]
  37. embeddings = self.embedding_model.embed_documents(texts)
  38. self.vector_store.add_documents(documents, embeddings)
  39. def _call_llm(self, prompt, system_prompt=None):
  40. messages = []
  41. if system_prompt:
  42. messages.append({"role": "system", "content": system_prompt})
  43. messages.append({"role": "user", "content": prompt})
  44. response = self.llm_client.chat.completions.create(
  45. model=self.llm_model,
  46. messages=messages,
  47. temperature=0.1,
  48. max_tokens=2048,
  49. )
  50. return response.choices[0].message.content
  51. def _format_context(self, docs):
  52. context_parts = []
  53. for i, (doc, score) in enumerate(docs, 1):
  54. source = doc.metadata.get("source", "unknown")
  55. context_parts.append("[{}] (Score: {:.3f}, Source: {})\n{}".format(i, score, source, doc.page_content))
  56. return "\n\n---\n\n".join(context_parts)
  57. @abstractmethod
  58. def retrieve(self, query, k=10):
  59. pass
  60. @abstractmethod
  61. def generate(self, query, context):
  62. pass
  63. def query(self, query, k=10):
  64. start_total = time.time()
  65. t0 = time.time()
  66. docs = self.retrieve(query, k)
  67. retrieval_time = time.time() - t0
  68. context = self._format_context(docs)
  69. t1 = time.time()
  70. answer = self.generate(query, context)
  71. generation_time = time.time() - t1
  72. total_time = time.time() - start_total
  73. return RAGResult(
  74. answer=answer,
  75. retrieved_docs=docs,
  76. latency_retrieval=retrieval_time,
  77. latency_generation=generation_time,
  78. latency_total=total_time,
  79. metadata={"method": self.name, "num_context_tokens": len(context)},
  80. )
  81. class NaiveRAG(BaseRAG):
  82. def __init__(self, retrieval_prompt_template=None, **kwargs):
  83. super().__init__(**kwargs)
  84. self.retrieval_prompt_template = retrieval_prompt_template or (
  85. "根据以下参考文档,回答问题。\n\n"
  86. "参考文档:\n{context}\n\n"
  87. "问题:{query}\n\n"
  88. "请详细回答,如果参考文档中没有相关信息,请说明无法从文档中找到答案。"
  89. )
  90. def retrieve(self, query, k=10):
  91. query_embedding = self.embedding_model.embed_query(query)
  92. return self.vector_store.similarity_search(query_embedding, k)
  93. def generate(self, query, context):
  94. prompt = self.retrieval_prompt_template.format(context=context, query=query)
  95. return self._call_llm(prompt)
  96. class RerankRAG(BaseRAG):
  97. def __init__(self, rerank_model=None, rerank_top_k=5, **kwargs):
  98. super().__init__(**kwargs)
  99. self.rerank_model = rerank_model
  100. self.rerank_top_k = rerank_top_k
  101. self.initial_k = 20
  102. def retrieve(self, query, k=10):
  103. query_embedding = self.embedding_model.embed_query(query)
  104. initial_docs = self.vector_store.similarity_search(query_embedding, self.initial_k)
  105. if self.rerank_model and len(initial_docs) > 0:
  106. texts = [doc.page_content for doc, _ in initial_docs]
  107. pairs = [(query, text) for text in texts]
  108. scores = self.rerank_model.compute_score(pairs)
  109. if isinstance(scores, (int, float)):
  110. scores = [scores]
  111. reranked = list(zip(initial_docs, scores))
  112. reranked.sort(key=lambda x: x[1], reverse=True)
  113. return [(doc, float(score)) for (doc, _), score in reranked[:k]]
  114. return initial_docs[:k]
  115. def generate(self, query, context):
  116. prompt = (
  117. "根据以下经过重排序的参考文档,回答问题。\n\n"
  118. "参考文档:\n{}\n\n"
  119. "问题:{}\n\n"
  120. "请详细回答。".format(context, query)
  121. )
  122. return self._call_llm(prompt)
  123. class ParentDocumentRAG(BaseRAG):
  124. def __init__(self, parent_chunk_size=1500, **kwargs):
  125. super().__init__(**kwargs)
  126. self.parent_chunk_size = parent_chunk_size
  127. self.parent_docs = []
  128. self.child_to_parent = {}
  129. def index_documents(self, documents):
  130. self.parent_docs = []
  131. self.child_to_parent = {}
  132. for i, doc in enumerate(documents):
  133. words = doc.page_content.split()
  134. for j in range(0, len(words), self.parent_chunk_size):
  135. parent_text = " ".join(words[j:j + self.parent_chunk_size])
  136. parent_doc = Document(
  137. page_content=parent_text,
  138. metadata=dict(doc.metadata, chunk_index=j, is_parent=True)
  139. )
  140. self.parent_docs.append(parent_doc)
  141. child_texts = []
  142. for doc in documents:
  143. words = doc.page_content.split()
  144. for j in range(0, len(words), self.parent_chunk_size):
  145. chunk_words = words[j:j + self.parent_chunk_size]
  146. for k in range(0, len(chunk_words), 512):
  147. child_text = " ".join(chunk_words[k:k + 512])
  148. child_id = len(child_texts)
  149. child_texts.append(child_text)
  150. self.child_to_parent[child_id] = child_id // 3
  151. if child_texts:
  152. embeddings = self.embedding_model.embed_documents(child_texts)
  153. child_docs = [
  154. Document(page_content=text, metadata={"is_parent": False})
  155. for text in child_texts
  156. ]
  157. self.vector_store.add_documents(child_docs, embeddings)
  158. logger.info("ParentDocumentRAG: {} parents, {} children".format(len(self.parent_docs), len(child_texts)))
  159. def retrieve(self, query, k=10):
  160. query_embedding = self.embedding_model.embed_query(query)
  161. child_results = self.vector_store.similarity_search(query_embedding, k * 2)
  162. parent_map = {}
  163. for child_doc, score in child_results:
  164. for i, parent_doc in enumerate(self.parent_docs):
  165. if child_doc.page_content[:50] in parent_doc.page_content:
  166. if i not in parent_map or score > parent_map[i]:
  167. parent_map[i] = (parent_doc, score)
  168. break
  169. parent_results = sorted(parent_map.values(), key=lambda x: x[1], reverse=True)
  170. return parent_results[:k]
  171. def generate(self, query, context):
  172. prompt = (
  173. "根据以下参考文档(父子文档检索,包含完整上下文),回答问题。\n\n"
  174. "参考文档:\n{}\n\n"
  175. "问题:{}\n\n"
  176. "请详细回答。".format(context, query)
  177. )
  178. return self._call_llm(prompt)
  179. class LLMFilterRAG(BaseRAG):
  180. def __init__(self, filter_threshold=0.5, **kwargs):
  181. super().__init__(**kwargs)
  182. self.filter_threshold = filter_threshold
  183. def _score_relevance(self, query, doc):
  184. prompt = (
  185. "评估以下文档与问题的相关性,给出0-1之间的分数。只返回分数数字。\n\n"
  186. "问题:{}\n\n"
  187. "文档:{}\n\n"
  188. "相关性分数:".format(query, doc.page_content[:300])
  189. )
  190. try:
  191. response = self._call_llm(prompt).strip()
  192. return float(response)
  193. except:
  194. return 0.5
  195. def retrieve(self, query, k=10):
  196. query_embedding = self.embedding_model.embed_query(query)
  197. initial_results = self.vector_store.similarity_search(query_embedding, k * 3)
  198. filtered_results = []
  199. for doc, score in initial_results:
  200. relevance = self._score_relevance(query, doc)
  201. if relevance >= self.filter_threshold:
  202. filtered_results.append((doc, score * relevance))
  203. filtered_results.sort(key=lambda x: x[1], reverse=True)
  204. return filtered_results[:k]
  205. def generate(self, query, context):
  206. prompt = (
  207. "根据以下经过LLM精选的参考文档,回答问题。\n\n"
  208. "参考文档:\n{}\n\n"
  209. "问题:{}\n\n"
  210. "请详细回答。".format(context, query)
  211. )
  212. return self._call_llm(prompt)
  213. class QueryRoutingRAG(BaseRAG):
  214. def __init__(self, **kwargs):
  215. super().__init__(**kwargs)
  216. def _classify_query(self, query):
  217. categories = [
  218. "budget", "deadline", "qualification", "evaluation",
  219. "payment", "warranty", "delivery", "contact", "scope"
  220. ]
  221. categories_str = ", ".join(categories)
  222. prompt = (
  223. "将以下问题分类到以下类别之一:{}\n\n"
  224. "问题:{}\n\n"
  225. "类别:".format(categories_str, query)
  226. )
  227. return self._call_llm(prompt).strip().lower()
  228. def retrieve(self, query, k=10):
  229. query_embedding = self.embedding_model.embed_query(query)
  230. semantic_results = self.vector_store.similarity_search(query_embedding, k)
  231. category = self._classify_query(query)
  232. logger.info("QueryRouting: category={}".format(category))
  233. category_boost = {}
  234. for doc, score in semantic_results:
  235. content_lower = doc.page_content.lower()
  236. category_keywords = {
  237. "budget": ["预算", "金额", "价格", "费用", "报价"],
  238. "deadline": ["截止", "时间", "日期", "开标"],
  239. "qualification": ["资格", "要求", "证书", "业绩"],
  240. "evaluation": ["评标", "评价", "分数", "方法"],
  241. "payment": ["付款", "结算", "进度", "保证金"],
  242. "warranty": ["质保", "维修", "售后", "服务"],
  243. "delivery": ["交货", "工期", "交付", "地点"],
  244. "contact": ["联系人", "电话", "邮箱"],
  245. "scope": ["范围", "内容", "清单", "设备"],
  246. }
  247. keywords = category_keywords.get(category, [])
  248. boost = sum(1 for kw in keywords if kw in content_lower) * 0.1
  249. category_boost[id(doc)] = boost
  250. enhanced_results = []
  251. for doc, score in semantic_results:
  252. boost = category_boost.get(id(doc), 0)
  253. enhanced_results.append((doc, score + boost))
  254. enhanced_results.sort(key=lambda x: x[1], reverse=True)
  255. return enhanced_results[:k]
  256. def generate(self, query, context):
  257. prompt = (
  258. "根据以下经过查询路由的参考文档,回答问题。\n\n"
  259. "参考文档:\n{}\n\n"
  260. "问题:{}\n\n"
  261. "请详细回答。".format(context, query)
  262. )
  263. return self._call_llm(prompt)
  264. class MetadataFilterRAG(BaseRAG):
  265. def __init__(self, **kwargs):
  266. super().__init__(**kwargs)
  267. def retrieve(self, query, k=10, metadata_filter=None):
  268. query_embedding = self.embedding_model.embed_query(query)
  269. all_results = self.vector_store.similarity_search(query_embedding, k * 3)
  270. if metadata_filter:
  271. filtered = []
  272. for doc, score in all_results:
  273. match = all(
  274. doc.metadata.get(key) == value
  275. for key, value in metadata_filter.items()
  276. )
  277. if match:
  278. filtered.append((doc, score))
  279. return filtered[:k]
  280. return all_results[:k]
  281. def generate(self, query, context):
  282. prompt = (
  283. "根据以下经过元数据筛选的参考文档,回答问题。\n\n"
  284. "参考文档:\n{}\n\n"
  285. "问题:{}\n\n"
  286. "请详细回答。".format(context, query)
  287. )
  288. return self._call_llm(prompt)
  289. class AdaptiveRAG(BaseRAG):
  290. def __init__(self, **kwargs):
  291. super().__init__(**kwargs)
  292. def _determine_strategy(self, query):
  293. prompt = (
  294. "根据问题类型,选择最合适的检索策略。\n"
  295. "可选策略:\n"
  296. "- semantic: 语义检索,适合概念性问题\n"
  297. "- keyword: 关键词检索,适合精确实体匹配\n"
  298. "- multi: 多路检索,适合复杂综合问题\n\n"
  299. "问题:{}\n\n"
  300. "策略名称(只返回name):".format(query)
  301. )
  302. return self._call_llm(prompt).strip().lower()
  303. def retrieve(self, query, k=10):
  304. strategy = self._determine_strategy(query)
  305. logger.info("AdaptiveRAG: strategy={}".format(strategy))
  306. if strategy in ["keyword", "exact"]:
  307. try:
  308. from rank_bm25 import BM25Okapi
  309. except:
  310. BM25Okapi = None
  311. if BM25Okapi:
  312. all_docs = self.vector_store.documents if hasattr(self.vector_store, 'documents') else []
  313. if all_docs:
  314. texts = [doc.page_content.split() for doc in all_docs]
  315. bm25 = BM25Okapi(texts)
  316. scores = bm25.get_scores(query.split())
  317. results = []
  318. for i, score in enumerate(scores):
  319. if score > 0:
  320. results.append((all_docs[i], float(score)))
  321. results.sort(key=lambda x: x[1], reverse=True)
  322. return results[:k]
  323. query_embedding = self.embedding_model.embed_query(query)
  324. semantic_results = self.vector_store.similarity_search(query_embedding, k)
  325. if strategy in ["multi", "comprehensive"]:
  326. all_results = {}
  327. for doc, score in semantic_results:
  328. doc_id = doc.page_content[:100]
  329. all_results[doc_id] = (doc, score)
  330. try:
  331. from rank_bm25 import BM25Okapi
  332. all_docs = self.vector_store.documents if hasattr(self.vector_store, 'documents') else []
  333. if all_docs:
  334. texts = [doc.page_content.split() for doc in all_docs]
  335. bm25 = BM25Okapi(texts)
  336. bm25_scores = bm25.get_scores(query.split())
  337. max_bm25 = max(bm25_scores) if max(bm25_scores) > 0 else 1.0
  338. for i, doc in enumerate(all_docs):
  339. doc_id = doc.page_content[:100]
  340. norm_bm25 = bm25_scores[i] / max_bm25
  341. if doc_id in all_results:
  342. all_results[doc_id] = (doc, all_results[doc_id][1] * 0.6 + norm_bm25 * 0.4)
  343. else:
  344. all_results[doc_id] = (doc, norm_bm25)
  345. except:
  346. pass
  347. sorted_results = sorted(all_results.values(), key=lambda x: x[1], reverse=True)
  348. return sorted_results[:k]
  349. return semantic_results[:k]
  350. def generate(self, query, context):
  351. prompt = (
  352. "根据以下经过自适应检索的参考文档,回答问题。\n\n"
  353. "参考文档:\n{}\n\n"
  354. "问题:{}\n\n"
  355. "请详细回答。".format(context, query)
  356. )
  357. return self._call_llm(prompt)
  358. class HybridSearchRAG(BaseRAG):
  359. def __init__(self, bm25_index=None, semantic_weight=0.5, **kwargs):
  360. super().__init__(**kwargs)
  361. self.bm25_index = bm25_index
  362. self.semantic_weight = semantic_weight
  363. self._all_texts = []
  364. def index_documents(self, documents):
  365. super().index_documents(documents)
  366. if self.bm25_index is not None:
  367. from rank_bm25 import BM25Okapi
  368. self._all_texts = [doc.page_content.split() for doc in documents]
  369. self.bm25_index = BM25Okapi(self._all_texts)
  370. def retrieve(self, query, k=10):
  371. query_embedding = self.embedding_model.embed_query(query)
  372. semantic_results = self.vector_store.similarity_search(query_embedding, k * 2)
  373. if self.bm25_index and self._all_texts:
  374. from rank_bm25 import BM25Okapi
  375. query_tokens = query.split()
  376. bm25_scores = self.bm25_index.get_scores(query_tokens)
  377. max_bm25 = max(bm25_scores) if bm25_scores.max() > 0 else 1.0
  378. normalized_bm25 = bm25_scores / max_bm25
  379. combined = []
  380. for i, (doc, sem_score) in enumerate(semantic_results):
  381. bm25_score = normalized_bm25[i] if i < len(normalized_bm25) else 0.0
  382. combined_score = (
  383. self.semantic_weight * sem_score +
  384. (1 - self.semantic_weight) * bm25_score
  385. )
  386. combined.append((doc, combined_score))
  387. combined.sort(key=lambda x: x[1], reverse=True)
  388. return combined[:k]
  389. return semantic_results[:k]
  390. def generate(self, query, context):
  391. prompt = (
  392. "根据以下参考文档(结合语义搜索和BM25关键词搜索),回答问题。\n\n"
  393. "参考文档:\n{}\n\n"
  394. "问题:{}\n\n"
  395. "请详细回答。".format(context, query)
  396. )
  397. return self._call_llm(prompt)
  398. class MultiQueryRAG(BaseRAG):
  399. def __init__(self, num_queries=3, **kwargs):
  400. super().__init__(**kwargs)
  401. self.num_queries = num_queries
  402. def _generate_queries(self, query):
  403. prompt = (
  404. "请将以下问题扩展为{}个不同角度的问题,"
  405. "用于检索更全面的信息。每行一个问题。\n\n"
  406. "原问题:{}\n\n"
  407. "扩展问题:".format(self.num_queries, query)
  408. )
  409. response = self._call_llm(prompt)
  410. queries = [q.strip() for q in response.strip().split("\n") if q.strip()]
  411. queries.insert(0, query)
  412. return queries[:self.num_queries + 1]
  413. def retrieve(self, query, k=10):
  414. queries = self._generate_queries(query)
  415. all_docs = {}
  416. for q in queries:
  417. q_embedding = self.embedding_model.embed_query(q)
  418. results = self.vector_store.similarity_search(q_embedding, k)
  419. for doc, score in results:
  420. doc_id = doc.page_content[:100]
  421. if doc_id not in all_docs or score > all_docs[doc_id][1]:
  422. all_docs[doc_id] = (doc, score)
  423. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  424. return sorted_docs[:k]
  425. def generate(self, query, context):
  426. prompt = (
  427. "根据以下参考文档,回答问题。\n\n"
  428. "参考文档:\n{}\n\n"
  429. "问题:{}\n\n"
  430. "请综合多个查询角度的检索结果,详细回答。".format(context, query)
  431. )
  432. return self._call_llm(prompt)
  433. class HyDERAG(BaseRAG):
  434. def __init__(self, num_hypotheses=3, **kwargs):
  435. super().__init__(**kwargs)
  436. self.num_hypotheses = num_hypotheses
  437. def _generate_hypothetical_docs(self, query):
  438. prompt = (
  439. "假设你是一个招投标专家,请根据以下问题,"
  440. "生成{}个可能包含答案的假想文档段落。"
  441. "每个段落用'<doc>'和'</doc>'分隔。\n\n"
  442. "问题:{}\n\n"
  443. "假想文档:".format(self.num_hypotheses, query)
  444. )
  445. response = self._call_llm(prompt)
  446. docs = re.findall(r'<doc>(.*?)</doc>', response, re.DOTALL)
  447. if not docs:
  448. docs = [response]
  449. return docs[:self.num_hypotheses]
  450. def retrieve(self, query, k=10):
  451. hypothetical_docs = self._generate_hypothetical_docs(query)
  452. all_results = {}
  453. for hypo_doc in hypothetical_docs:
  454. hypo_embedding = self.embedding_model.embed_query(hypo_doc)
  455. results = self.vector_store.similarity_search(hypo_embedding, k)
  456. for doc, score in results:
  457. doc_id = doc.page_content[:100]
  458. if doc_id not in all_results or score > all_results[doc_id][1]:
  459. all_results[doc_id] = (doc, score)
  460. query_embedding = self.embedding_model.embed_query(query)
  461. direct_results = self.vector_store.similarity_search(query_embedding, k)
  462. for doc, score in direct_results:
  463. doc_id = doc.page_content[:100]
  464. if doc_id not in all_results or score > all_results[doc_id][1]:
  465. all_results[doc_id] = (doc, score)
  466. sorted_docs = sorted(all_results.values(), key=lambda x: x[1], reverse=True)
  467. return sorted_docs[:k]
  468. def generate(self, query, context):
  469. prompt = (
  470. "根据以下参考文档,回答问题。\n\n"
  471. "参考文档:\n{}\n\n"
  472. "问题:{}\n\n"
  473. "请详细回答。".format(context, query)
  474. )
  475. return self._call_llm(prompt)
  476. class StepBackRAG(BaseRAG):
  477. def __init__(self, **kwargs):
  478. super().__init__(**kwargs)
  479. def _generate_step_back_query(self, query):
  480. prompt = (
  481. "请将以下具体问题抽象为一个更通用的高层次问题,"
  482. "用于查找相关背景知识。只返回抽象后的问题。\n\n"
  483. "具体问题:{}\n\n"
  484. "高层次问题:".format(query)
  485. )
  486. return self._call_llm(prompt).strip()
  487. def retrieve(self, query, k=10):
  488. step_back_query = self._generate_step_back_query(query)
  489. logger.info("StepBack query: {}".format(step_back_query))
  490. query_embedding = self.embedding_model.embed_query(query)
  491. direct_results = self.vector_store.similarity_search(query_embedding, k // 2)
  492. step_back_embedding = self.embedding_model.embed_query(step_back_query)
  493. step_back_results = self.vector_store.similarity_search(step_back_embedding, k // 2)
  494. all_docs = {}
  495. for doc, score in direct_results + step_back_results:
  496. doc_id = doc.page_content[:100]
  497. if doc_id not in all_docs or score > all_docs[doc_id][1]:
  498. all_docs[doc_id] = (doc, score)
  499. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  500. return sorted_docs[:k]
  501. def generate(self, query, context):
  502. prompt = (
  503. "结合以下参考文档(包含具体问题和抽象问题的检索结果),"
  504. "回答问题。\n\n"
  505. "参考文档:\n{}\n\n"
  506. "问题:{}\n\n"
  507. "请详细回答。".format(context, query)
  508. )
  509. return self._call_llm(prompt)
  510. class ContextualCompressionRAG(BaseRAG):
  511. def __init__(self, compression_llm=None, **kwargs):
  512. super().__init__(**kwargs)
  513. self.compression_llm = compression_llm or llm_client
  514. def _compress_doc(self, query, doc):
  515. prompt = (
  516. "请压缩以下文档片段,仅保留与问题相关的信息。\n\n"
  517. "问题:{}\n\n"
  518. "文档:{}\n\n"
  519. "压缩后的内容:".format(query, doc.page_content)
  520. )
  521. messages = []
  522. messages.append({"role": "user", "content": prompt})
  523. response = self.compression_llm.chat.completions.create(
  524. model=self.llm_model,
  525. messages=messages,
  526. temperature=0.0,
  527. max_tokens=512,
  528. )
  529. return response.choices[0].message.content
  530. def retrieve(self, query, k=10):
  531. query_embedding = self.embedding_model.embed_query(query)
  532. initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
  533. compressed_results = []
  534. for doc, score in initial_results:
  535. compressed_text = self._compress_doc(query, doc)
  536. compressed_doc = Document(page_content=compressed_text, metadata=doc.metadata)
  537. compressed_results.append((compressed_doc, score))
  538. return compressed_results[:k]
  539. def generate(self, query, context):
  540. prompt = (
  541. "根据以下经过上下文压缩的参考文档,回答问题。\n\n"
  542. "参考文档:\n{}\n\n"
  543. "问题:{}\n\n"
  544. "请详细回答。".format(context, query)
  545. )
  546. return self._call_llm(prompt)
  547. class SelfRAG(BaseRAG):
  548. def __init__(self, relevance_threshold=0.5, support_threshold=0.5, retrieval_threshold=0.6, critic_model=None, **kwargs):
  549. super().__init__(**kwargs)
  550. self.relevance_threshold = relevance_threshold
  551. self.support_threshold = support_threshold
  552. self.retrieval_threshold = retrieval_threshold
  553. self.critic_model = critic_model
  554. def _is_retrieval_needed(self, query):
  555. prompt = (
  556. "判断以下问题是否需要检索外部知识才能回答。"
  557. "只需回答'是'或'否'。\n\n"
  558. "问题:{}\n\n"
  559. "是否需要检索:".format(query)
  560. )
  561. response = self._call_llm(prompt).strip().lower()
  562. return "是" in response or "yes" in response
  563. def _evaluate_relevance(self, query, doc):
  564. prompt = (
  565. "评估以下文档片段与问题的相关性,给出0-1之间的分数。\n\n"
  566. "问题:{}\n\n"
  567. "文档:{}\n\n"
  568. "相关性分数:".format(query, doc.page_content)
  569. )
  570. response = self._call_llm(prompt).strip()
  571. try:
  572. return float(response)
  573. except:
  574. return 0.5
  575. def _evaluate_support(self, query, answer, doc):
  576. prompt = (
  577. "评估以下回答是否得到了文档的支持,给出0-1之间的分数。\n\n"
  578. "文档:{}\n\n"
  579. "回答:{}\n\n"
  580. "支持分数:".format(doc.page_content, answer)
  581. )
  582. response = self._call_llm(prompt).strip()
  583. try:
  584. return float(response)
  585. except:
  586. return 0.5
  587. def _evaluate_usefulness(self, query, answer):
  588. prompt = (
  589. "评估以下回答对于问题的有用性,给出0-1之间的分数。\n\n"
  590. "问题:{}\n\n"
  591. "回答:{}\n\n"
  592. "有用性分数:".format(query, answer)
  593. )
  594. response = self._call_llm(prompt).strip()
  595. try:
  596. return float(response)
  597. except:
  598. return 0.5
  599. def retrieve(self, query, k=10):
  600. if not self._is_retrieval_needed(query):
  601. return []
  602. query_embedding = self.embedding_model.embed_query(query)
  603. initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
  604. filtered_results = []
  605. for doc, score in initial_results:
  606. relevance = self._evaluate_relevance(query, doc)
  607. if relevance >= self.relevance_threshold:
  608. filtered_results.append((doc, relevance * score))
  609. filtered_results.sort(key=lambda x: x[1], reverse=True)
  610. return filtered_results[:k]
  611. def generate(self, query, context):
  612. prompt = (
  613. "根据以下经过自反思筛选的参考文档,回答问题。\n\n"
  614. "参考文档:\n{}\n\n"
  615. "问题:{}\n\n"
  616. "请确保回答有充分的文档支持,详细回答。".format(context, query)
  617. )
  618. return self._call_llm(prompt)
  619. class CorrectiveRAG(BaseRAG):
  620. def __init__(self, correctness_threshold=0.6, **kwargs):
  621. super().__init__(**kwargs)
  622. self.correctness_threshold = correctness_threshold
  623. self.web_search_results = []
  624. def _evaluate_correctness(self, query, docs):
  625. if not docs:
  626. return 0.0
  627. context = "\n\n".join([doc.page_content for doc, _ in docs[:5]])
  628. prompt = (
  629. "评估以下文档集合是否能够正确回答问题,给出0-1之间的分数。\n\n"
  630. "问题:{}\n\n"
  631. "文档集合:\n{}\n\n"
  632. "正确性分数:".format(query, context)
  633. )
  634. response = self._call_llm(prompt).strip()
  635. try:
  636. return float(response)
  637. except:
  638. return 0.5
  639. def _web_search(self, query):
  640. return "[Web search results for: {}] - Simulated external knowledge".format(query)
  641. def retrieve(self, query, k=10):
  642. query_embedding = self.embedding_model.embed_query(query)
  643. initial_results = self.vector_store.similarity_search(query_embedding, k)
  644. correctness_score = self._evaluate_correctness(query, initial_results)
  645. if correctness_score < self.correctness_threshold:
  646. logger.info("CRAG: Correctness score {} below threshold, adding web search".format(correctness_score))
  647. web_result = self._web_search(query)
  648. web_doc = Document(
  649. page_content=web_result,
  650. metadata={"source": "web_search"}
  651. )
  652. initial_results.insert(0, (web_doc, 0.8))
  653. return initial_results
  654. def generate(self, query, context):
  655. prompt = (
  656. "根据以下参考文档(可能包含外部搜索结果),回答问题。\n\n"
  657. "参考文档:\n{}\n\n"
  658. "问题:{}\n\n"
  659. "请详细回答。如果包含外部搜索结果,请注明。".format(context, query)
  660. )
  661. return self._call_llm(prompt)
  662. class FLARERAG(BaseRAG):
  663. def __init__(self, max_iterations=3, **kwargs):
  664. super().__init__(**kwargs)
  665. self.max_iterations = max_iterations
  666. def _need_more_retrieval(self, query, current_answer):
  667. prompt = (
  668. "基于当前已有的信息,判断是否还需要更多检索才能完整回答问题。\n\n"
  669. "问题:{}\n\n"
  670. "当前已有信息的回答:{}\n\n"
  671. "是否需要更多检索(只回答是/否):".format(query, current_answer)
  672. )
  673. response = self._call_llm(prompt).strip().lower()
  674. return "是" in response or "yes" in response
  675. def _generate_next_query(self, query, current_answer):
  676. prompt = (
  677. "基于当前回答的不足,生成一个新的查询来补充信息。\n\n"
  678. "原问题:{}\n\n"
  679. "当前回答:{}\n\n"
  680. "新查询:".format(query, current_answer)
  681. )
  682. return self._call_llm(prompt).strip()
  683. def retrieve(self, query, k=10):
  684. all_docs = {}
  685. current_query = query
  686. current_answer = ""
  687. for iteration in range(self.max_iterations):
  688. query_embedding = self.embedding_model.embed_query(current_query)
  689. results = self.vector_store.similarity_search(query_embedding, k)
  690. for doc, score in results:
  691. doc_id = doc.page_content[:100]
  692. if doc_id not in all_docs or score > all_docs[doc_id][1]:
  693. all_docs[doc_id] = (doc, score)
  694. context = self._format_context(list(all_docs.values()))
  695. current_answer = self._generate_partial_answer(query, context)
  696. if not self._need_more_retrieval(query, current_answer):
  697. break
  698. current_query = self._generate_next_query(query, current_answer)
  699. logger.info("FLARE iteration {}, new query: {}".format(iteration + 1, current_query))
  700. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  701. return sorted_docs[:k]
  702. def _generate_partial_answer(self, query, context):
  703. prompt = (
  704. "根据以下参考文档,给出问题的回答。\n\n"
  705. "参考文档:\n{}\n\n"
  706. "问题:{}\n\n"
  707. "回答:".format(context, query)
  708. )
  709. return self._call_llm(prompt)
  710. def generate(self, query, context):
  711. return self._generate_partial_answer(query, context)
  712. class RAPTORRAG(BaseRAG):
  713. def __init__(self, max_clusters=50, summary_length=256, num_tree_levels=2, **kwargs):
  714. super().__init__(**kwargs)
  715. self.max_clusters = max_clusters
  716. self.summary_length = summary_length
  717. self.num_tree_levels = num_tree_levels
  718. self.hierarchical_docs = []
  719. def _summarize_texts(self, texts):
  720. summaries = []
  721. batch_size = 5
  722. for i in range(0, len(texts), batch_size):
  723. batch = texts[i:i + batch_size]
  724. combined = "\n\n".join(batch)
  725. prompt = (
  726. "请用一段话总结以下内容的核心要点({}字以内):\n\n"
  727. "{}\n\n"
  728. "总结:".format(self.summary_length, combined)
  729. )
  730. summary = self._call_llm(prompt)
  731. summaries.append(summary)
  732. return summaries
  733. def _cluster_embeddings(self, embeddings):
  734. from sklearn.cluster import AgglomerativeClustering
  735. emb_np = np.array(embeddings)
  736. n_clusters = min(self.max_clusters, len(embeddings))
  737. if n_clusters < 2:
  738. return [list(range(len(embeddings)))]
  739. clustering = AgglomerativeClustering(n_clusters=n_clusters)
  740. labels = clustering.fit_predict(emb_np)
  741. clusters = {}
  742. for i, label in enumerate(labels):
  743. if label not in clusters:
  744. clusters[label] = []
  745. clusters[label].append(i)
  746. return list(clusters.values())
  747. def build_tree(self, documents):
  748. self.index_documents(documents)
  749. current_level_docs = documents
  750. for level in range(self.num_tree_levels):
  751. texts = [doc.page_content for doc in current_level_docs]
  752. embeddings = self.embedding_model.embed_documents(texts)
  753. clusters = self._cluster_embeddings(embeddings)
  754. summary_docs = []
  755. for cluster in clusters:
  756. cluster_texts = [texts[i] for i in cluster]
  757. summaries = self._summarize_texts(cluster_texts)
  758. for summary in summaries:
  759. summary_docs.append(Document(
  760. page_content=summary,
  761. metadata={"level": level + 1, "num_source_docs": len(cluster)}
  762. ))
  763. self.hierarchical_docs.extend(summary_docs)
  764. if summary_docs:
  765. summary_embeddings = self.embedding_model.embed_documents([d.page_content for d in summary_docs])
  766. self.vector_store.add_documents(summary_docs, summary_embeddings)
  767. current_level_docs = summary_docs
  768. logger.info("RAPTOR level {}: {} summaries created".format(level + 1, len(summary_docs)))
  769. def retrieve(self, query, k=10):
  770. query_embedding = self.embedding_model.embed_query(query)
  771. all_results = []
  772. all_results.extend(self.vector_store.similarity_search(query_embedding, k))
  773. sorted_results = sorted(all_results, key=lambda x: x[1], reverse=True)
  774. return sorted_results[:k]
  775. def generate(self, query, context):
  776. prompt = (
  777. "根据以下参考文档(包含多层次摘要信息),回答问题。\n\n"
  778. "参考文档:\n{}\n\n"
  779. "问题:{}\n\n"
  780. "请详细回答。".format(context, query)
  781. )
  782. return self._call_llm(prompt)
  783. class EnsembleRAG(BaseRAG):
  784. def __init__(self, methods=None, **kwargs):
  785. super().__init__(**kwargs)
  786. self.methods = methods or ["naive", "hybrid", "multi_query"]
  787. def retrieve(self, query, k=10):
  788. all_docs = {}
  789. query_embedding = self.embedding_model.embed_query(query)
  790. base_results = self.vector_store.similarity_search(query_embedding, k * 2)
  791. for doc, score in base_results:
  792. doc_id = doc.page_content[:100]
  793. all_docs[doc_id] = (doc, score)
  794. try:
  795. from rank_bm25 import BM25Okapi
  796. all_texts = [doc.page_content.split() for doc, _ in base_results]
  797. bm25 = BM25Okapi(all_texts)
  798. query_tokens = query.split()
  799. bm25_scores = bm25.get_scores(query_tokens)
  800. max_bm25 = max(bm25_scores) if bm25_scores.max() > 0 else 1.0
  801. for i, (doc, _) in enumerate(base_results):
  802. doc_id = doc.page_content[:100]
  803. bm25_score = bm25_scores[i] / max_bm25
  804. all_docs[doc_id] = (doc, all_docs[doc_id][1] * 0.5 + bm25_score * 0.5)
  805. except:
  806. pass
  807. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  808. return sorted_docs[:k]
  809. def generate(self, query, context):
  810. prompt = (
  811. "根据以下参考文档(通过集成多种检索方法获取),回答问题。\n\n"
  812. "参考文档:\n{}\n\n"
  813. "问题:{}\n\n"
  814. "请综合多种检索策略的结果,详细回答。".format(context, query)
  815. )
  816. return self._call_llm(prompt)
  817. class BidFieldExtractionRAG(BaseRAG):
  818. def __init__(self, fields=None, **kwargs):
  819. super().__init__(**kwargs)
  820. self.fields = fields or [
  821. "project_name", "project_code", "budget_amount", "currency",
  822. "bid_deadline", "bid_open_time", "bid_location",
  823. "purchaser_name", "purchaser_contact", "purchaser_phone",
  824. "agency_name", "agency_contact", "agency_phone",
  825. "qualification_requirements", "bid_bond_amount",
  826. "performance_bond_amount", "warranty_period",
  827. "delivery_time", "delivery_location", "payment_terms",
  828. "evaluation_method", "scope_of_work"
  829. ]
  830. def retrieve(self, query, k=10):
  831. all_docs = {}
  832. for field in self.fields:
  833. field_query = "{} {}".format(query, field)
  834. field_embedding = self.embedding_model.embed_query(field_query)
  835. results = self.vector_store.similarity_search(field_embedding, k // 3)
  836. for doc, score in results:
  837. doc_id = doc.page_content[:100]
  838. if doc_id not in all_docs or score > all_docs[doc_id][1]:
  839. all_docs[doc_id] = (doc, score)
  840. query_embedding = self.embedding_model.embed_query(query)
  841. direct_results = self.vector_store.similarity_search(query_embedding, k)
  842. for doc, score in direct_results:
  843. doc_id = doc.page_content[:100]
  844. if doc_id not in all_docs or score > all_docs[doc_id][1]:
  845. all_docs[doc_id] = (doc, score)
  846. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  847. return sorted_docs[:k]
  848. def generate(self, query, context):
  849. fields_str = "\n".join(["- {}".format(f) for f in self.fields])
  850. prompt = """你是一个招投标领域的专家。请根据提供的文档内容,提取以下字段信息:
  851. {}
  852. 文档内容:
  853. {}
  854. 请以JSON格式返回提取结果。如果某个字段无法从文档中提取,请返回null。
  855. JSON格式示例:
  856. {{
  857. "project_name": "项目名称",
  858. "budget_amount": 1000000,
  859. ...
  860. }}""".format(fields_str, context)
  861. return self._call_llm(prompt)
  862. class TableAwareRAG(BaseRAG):
  863. def __init__(self, **kwargs):
  864. super().__init__(**kwargs)
  865. def _extract_table_info(self, doc):
  866. content = doc.page_content
  867. table_pattern = r'(\|.*\|[\r\n]+)'
  868. tables = re.findall(table_pattern, content)
  869. if tables:
  870. return "[TABLE DETECTED]\n{}".format(''.join(tables))
  871. numbered_pattern = r'(\d+[\.、].*?(?=\d+[\.、]|$))'
  872. numbered_items = re.findall(numbered_pattern, content, re.DOTALL)
  873. if numbered_items:
  874. return "[STRUCTURED LIST DETECTED]\n{}".format(''.join(numbered_items))
  875. return content
  876. def retrieve(self, query, k=10):
  877. query_embedding = self.embedding_model.embed_query(query)
  878. initial_results = self.vector_store.similarity_search(query_embedding, k * 2)
  879. enhanced_results = []
  880. for doc, score in initial_results:
  881. enhanced_content = self._extract_table_info(doc)
  882. enhanced_doc = Document(
  883. page_content=enhanced_content,
  884. metadata=dict(doc.metadata, has_table="[TABLE" in enhanced_content)
  885. )
  886. table_bonus = 0.1 if "[TABLE" in enhanced_content else 0.0
  887. enhanced_results.append((enhanced_doc, score + table_bonus))
  888. enhanced_results.sort(key=lambda x: x[1], reverse=True)
  889. return enhanced_results[:k]
  890. def generate(self, query, context):
  891. prompt = (
  892. "根据以下参考文档(包含表格和结构化数据),回答问题。\n\n"
  893. "参考文档:\n{}\n\n"
  894. "问题:{}\n\n"
  895. "请特别注意表格和结构化数据中的信息,详细回答。".format(context, query)
  896. )
  897. return self._call_llm(prompt)
  898. class GraphRAG(BaseRAG):
  899. def __init__(self, **kwargs):
  900. super().__init__(**kwargs)
  901. self.graph = {}
  902. def _extract_entities(self, text):
  903. entities = re.findall(r'[一-龥]{2,10}(?:公司|单位|招标|投标|项目|金额|时间)', text)
  904. return list(set(entities))
  905. def build_graph(self, documents):
  906. self.index_documents(documents)
  907. for doc in documents:
  908. entities = self._extract_entities(doc.page_content)
  909. for entity in entities:
  910. if entity not in self.graph:
  911. self.graph[entity] = []
  912. self.graph[entity].append(doc)
  913. logger.info("Graph built with {} entities".format(len(self.graph)))
  914. def retrieve(self, query, k=10):
  915. query_embedding = self.embedding_model.embed_query(query)
  916. semantic_results = self.vector_store.similarity_search(query_embedding, k)
  917. query_entities = self._extract_entities(query)
  918. graph_docs = {}
  919. for entity in query_entities:
  920. if entity in self.graph:
  921. for doc in self.graph[entity]:
  922. doc_id = doc.page_content[:100]
  923. if doc_id not in graph_docs:
  924. graph_docs[doc_id] = (doc, 0.7)
  925. else:
  926. graph_docs[doc_id] = (doc, graph_docs[doc_id][1] + 0.1)
  927. all_docs = {}
  928. for doc, score in semantic_results:
  929. doc_id = doc.page_content[:100]
  930. all_docs[doc_id] = (doc, score)
  931. for doc_id, (doc, score) in graph_docs.items():
  932. if doc_id in all_docs:
  933. all_docs[doc_id] = (doc, all_docs[doc_id][1] * 0.6 + score * 0.4)
  934. else:
  935. all_docs[doc_id] = (doc, score)
  936. sorted_docs = sorted(all_docs.values(), key=lambda x: x[1], reverse=True)
  937. return sorted_docs[:k]
  938. def generate(self, query, context):
  939. prompt = (
  940. "根据以下参考文档(结合知识图谱和语义检索),回答问题。\n\n"
  941. "参考文档:\n{}\n\n"
  942. "问题:{}\n\n"
  943. "请详细回答。".format(context, query)
  944. )
  945. return self._call_llm(prompt)
  946. class BM25RAG(BaseRAG):
  947. def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
  948. self.embedding_model = None
  949. self.vector_store = None
  950. self.llm_client = llm_client
  951. self.llm_model = llm_model
  952. self.name = self.__class__.__name__
  953. from rank_bm25 import BM25Okapi
  954. self.bm25 = None
  955. self._all_texts = []
  956. self._all_documents = []
  957. def index_documents(self, documents):
  958. self._all_documents = documents
  959. self._all_texts = [doc.page_content.split() for doc in documents]
  960. from rank_bm25 import BM25Okapi
  961. self.bm25 = BM25Okapi(self._all_texts)
  962. logger.info("BM25 index built with {} documents".format(len(documents)))
  963. def retrieve(self, query, k=10):
  964. if self.bm25 is None:
  965. return []
  966. query_tokens = query.split()
  967. scores = self.bm25.get_scores(query_tokens)
  968. scored_docs = []
  969. for i, score in enumerate(scores):
  970. if score > 0:
  971. scored_docs.append((self._all_documents[i], float(score)))
  972. scored_docs.sort(key=lambda x: x[1], reverse=True)
  973. return scored_docs[:k]
  974. def generate(self, query, context):
  975. prompt = (
  976. "根据以下参考文档(BM25关键词检索),回答问题。\n\n"
  977. "参考文档:\n{}\n\n"
  978. "问题:{}\n\n"
  979. "请详细回答。".format(context, query)
  980. )
  981. return self._call_llm(prompt)
  982. def query(self, query, k=10):
  983. start_total = time.time()
  984. t0 = time.time()
  985. docs = self.retrieve(query, k)
  986. retrieval_time = time.time() - t0
  987. context = self._format_context(docs)
  988. t1 = time.time()
  989. answer = self.generate(query, context)
  990. generation_time = time.time() - t1
  991. total_time = time.time() - start_total
  992. return RAGResult(
  993. answer=answer,
  994. retrieved_docs=docs,
  995. latency_retrieval=retrieval_time,
  996. latency_generation=generation_time,
  997. latency_total=total_time,
  998. metadata={"method": self.name, "num_context_tokens": len(context)},
  999. )
  1000. class TFIDFRAG(BaseRAG):
  1001. def __init__(self, llm_client=None, llm_model="gpt-4o", **kwargs):
  1002. self.embedding_model = None
  1003. self.vector_store = None
  1004. self.llm_client = llm_client
  1005. self.llm_model = llm_model
  1006. self.name = self.__class__.__name__
  1007. self.vectorizer = None
  1008. self.tfidf_matrix = None
  1009. self._all_documents = []
  1010. def index_documents(self, documents):
  1011. self._all_documents = documents
  1012. texts = [doc.page_content for doc in documents]
  1013. from sklearn.feature_extraction.text import TfidfVectorizer
  1014. self.vectorizer = TfidfVectorizer()
  1015. self.tfidf_matrix = self.vectorizer.fit_transform(texts)
  1016. logger.info("TF-IDF index built with {} documents, vocab size: {}".format(len(documents), len(self.vectorizer.vocabulary_)))
  1017. def retrieve(self, query, k=10):
  1018. if self.tfidf_matrix is None:
  1019. return []
  1020. from sklearn.metrics.pairwise import cosine_similarity
  1021. query_vec = self.vectorizer.transform([query])
  1022. scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
  1023. k = min(k, len(scores))
  1024. top_indices = np.argsort(scores)[::-1][:k]
  1025. results = []
  1026. for idx in top_indices:
  1027. if scores[idx] > 0:
  1028. results.append((self._all_documents[idx], float(scores[idx])))
  1029. return results
  1030. def generate(self, query, context):
  1031. prompt = (
  1032. "根据以下参考文档(TF-IDF关键词检索),回答问题。\n\n"
  1033. "参考文档:\n{}\n\n"
  1034. "问题:{}\n\n"
  1035. "请详细回答。".format(context, query)
  1036. )
  1037. return self._call_llm(prompt)
  1038. def query(self, query, k=10):
  1039. start_total = time.time()
  1040. t0 = time.time()
  1041. docs = self.retrieve(query, k)
  1042. retrieval_time = time.time() - t0
  1043. context = self._format_context(docs)
  1044. t1 = time.time()
  1045. answer = self.generate(query, context)
  1046. generation_time = time.time() - t1
  1047. total_time = time.time() - start_total
  1048. return RAGResult(
  1049. answer=answer,
  1050. retrieved_docs=docs,
  1051. latency_retrieval=retrieval_time,
  1052. latency_generation=generation_time,
  1053. latency_total=total_time,
  1054. metadata={"method": self.name, "num_context_tokens": len(context)},
  1055. )
  1056. class KeywordRAG(BaseRAG):
  1057. def __init__(self, search_method="bm25", llm_client=None, llm_model="gpt-4o", **kwargs):
  1058. self.embedding_model = None
  1059. self.vector_store = None
  1060. self.llm_client = llm_client
  1061. self.llm_model = llm_model
  1062. self.name = self.__class__.__name__
  1063. self.search_method = search_method
  1064. self._all_documents = []
  1065. self._all_texts = []
  1066. self.bm25 = None
  1067. self.tfidf_matrix = None
  1068. self.vectorizer = None
  1069. def index_documents(self, documents):
  1070. self._all_documents = documents
  1071. self._all_texts = [doc.page_content for doc in documents]
  1072. if self.search_method == "bm25":
  1073. from rank_bm25 import BM25Okapi
  1074. tokenized = [t.split() for t in self._all_texts]
  1075. self.bm25 = BM25Okapi(tokenized)
  1076. elif self.search_method == "tfidf":
  1077. from sklearn.feature_extraction.text import TfidfVectorizer
  1078. self.vectorizer = TfidfVectorizer()
  1079. self.tfidf_matrix = self.vectorizer.fit_transform(self._all_texts)
  1080. logger.info("KeywordRAG ({}) index built with {} documents".format(self.search_method, len(documents)))
  1081. def retrieve(self, query, k=10):
  1082. if self.search_method == "bm25":
  1083. query_tokens = query.split()
  1084. scores = self.bm25.get_scores(query_tokens)
  1085. elif self.search_method == "tfidf":
  1086. from sklearn.metrics.pairwise import cosine_similarity
  1087. query_vec = self.vectorizer.transform([query])
  1088. scores = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
  1089. else:
  1090. return []
  1091. k = min(k, len(scores))
  1092. top_indices = np.argsort(scores)[::-1][:k]
  1093. results = []
  1094. for idx in top_indices:
  1095. if scores[idx] > 0:
  1096. results.append((self._all_documents[idx], float(scores[idx])))
  1097. return results
  1098. def generate(self, query, context):
  1099. prompt = (
  1100. "根据以下参考文档(关键词检索:{}),回答问题。\n\n"
  1101. "参考文档:\n{}\n\n"
  1102. "问题:{}\n\n"
  1103. "请详细回答。".format(self.search_method, context, query)
  1104. )
  1105. return self._call_llm(prompt)
  1106. def query(self, query, k=10):
  1107. start_total = time.time()
  1108. t0 = time.time()
  1109. docs = self.retrieve(query, k)
  1110. retrieval_time = time.time() - t0
  1111. context = self._format_context(docs)
  1112. t1 = time.time()
  1113. answer = self.generate(query, context)
  1114. generation_time = time.time() - t1
  1115. total_time = time.time() - start_total
  1116. return RAGResult(
  1117. answer=answer,
  1118. retrieved_docs=docs,
  1119. latency_retrieval=retrieval_time,
  1120. latency_generation=generation_time,
  1121. latency_total=total_time,
  1122. metadata={"method": self.name, "num_context_tokens": len(context)},
  1123. )