benchmark_all_methods.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. """
  2. Complete benchmark script for comparing all RAG methods
  3. This script demonstrates various RAG methods and compares their performance
  4. """
  5. import sys
  6. import os
  7. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  8. from openai import OpenAI
  9. from FlagEmbedding import FlagReranker
  10. from bdirag.document_processor import Document, DocumentProcessor
  11. from bdirag.embedding_models import SentenceTransformerEmbedding
  12. from bdirag.vector_stores import FAISSStore
  13. from bdirag.rag_methods import (
  14. NaiveRAG,
  15. RerankRAG,
  16. HybridSearchRAG,
  17. MultiQueryRAG,
  18. HyDERAG,
  19. SelfRAG,
  20. CorrectiveRAG,
  21. FLARERAG,
  22. RAPTORRAG,
  23. BidFieldExtractionRAG,
  24. TableAwareRAG,
  25. EnsembleRAG,
  26. GraphRAG,
  27. StepBackRAG,
  28. ContextualCompressionRAG,
  29. BM25RAG,
  30. TFIDFRAG,
  31. KeywordRAG,
  32. )
  33. from bdirag.benchmark import RAGBenchmark
  34. from examples.sample_data import SAMPLE_BIDDING_DOCS
  35. def setup_rag_methods(embedding_model, vector_store, llm_client, llm_model="gpt-4o"):
  36. """Initialize all RAG methods with shared components"""
  37. rerank_model = FlagReranker("BAAI/bge-reranker-large", use_fp16=True)
  38. methods = {}
  39. # Basic RAG Methods
  40. methods["NaiveRAG"] = NaiveRAG(
  41. embedding_model=embedding_model,
  42. vector_store=vector_store,
  43. llm_client=llm_client,
  44. llm_model=llm_model,
  45. )
  46. methods["RerankRAG"] = RerankRAG(
  47. embedding_model=embedding_model,
  48. vector_store=vector_store,
  49. llm_client=llm_client,
  50. llm_model=llm_model,
  51. rerank_model=rerank_model,
  52. rerank_top_k=5,
  53. )
  54. methods["HybridSearchRAG"] = HybridSearchRAG(
  55. embedding_model=embedding_model,
  56. vector_store=vector_store,
  57. llm_client=llm_client,
  58. llm_model=llm_model,
  59. semantic_weight=0.5,
  60. )
  61. # Advanced RAG Methods
  62. methods["MultiQueryRAG"] = MultiQueryRAG(
  63. embedding_model=embedding_model,
  64. vector_store=vector_store,
  65. llm_client=llm_client,
  66. llm_model=llm_model,
  67. num_queries=3,
  68. )
  69. methods["HyDERAG"] = HyDERAG(
  70. embedding_model=embedding_model,
  71. vector_store=vector_store,
  72. llm_client=llm_client,
  73. llm_model=llm_model,
  74. num_hypotheses=3,
  75. )
  76. methods["SelfRAG"] = SelfRAG(
  77. embedding_model=embedding_model,
  78. vector_store=vector_store,
  79. llm_client=llm_client,
  80. llm_model=llm_model,
  81. relevance_threshold=0.5,
  82. support_threshold=0.5,
  83. )
  84. methods["CorrectiveRAG"] = CorrectiveRAG(
  85. embedding_model=embedding_model,
  86. vector_store=vector_store,
  87. llm_client=llm_client,
  88. llm_model=llm_model,
  89. correctness_threshold=0.6,
  90. )
  91. methods["FLARERAG"] = FLARERAG(
  92. embedding_model=embedding_model,
  93. vector_store=vector_store,
  94. llm_client=llm_client,
  95. llm_model=llm_model,
  96. max_iterations=3,
  97. )
  98. methods["RAPTORRAG"] = RAPTORRAG(
  99. embedding_model=embedding_model,
  100. vector_store=vector_store,
  101. llm_client=llm_client,
  102. llm_model=llm_model,
  103. max_clusters=10,
  104. summary_length=256,
  105. num_tree_levels=2,
  106. )
  107. methods["StepBackRAG"] = StepBackRAG(
  108. embedding_model=embedding_model,
  109. vector_store=vector_store,
  110. llm_client=llm_client,
  111. llm_model=llm_model,
  112. )
  113. methods["ContextualCompressionRAG"] = ContextualCompressionRAG(
  114. embedding_model=embedding_model,
  115. vector_store=vector_store,
  116. llm_client=llm_client,
  117. compression_llm=llm_client,
  118. llm_model=llm_model,
  119. )
  120. methods["EnsembleRAG"] = EnsembleRAG(
  121. embedding_model=embedding_model,
  122. vector_store=vector_store,
  123. llm_client=llm_client,
  124. llm_model=llm_model,
  125. methods=["naive", "hybrid", "multi_query"],
  126. )
  127. # Bidding-specific RAG Methods
  128. methods["BidFieldExtractionRAG"] = BidFieldExtractionRAG(
  129. embedding_model=embedding_model,
  130. vector_store=vector_store,
  131. llm_client=llm_client,
  132. llm_model=llm_model,
  133. )
  134. methods["TableAwareRAG"] = TableAwareRAG(
  135. embedding_model=embedding_model,
  136. vector_store=vector_store,
  137. llm_client=llm_client,
  138. llm_model=llm_model,
  139. )
  140. methods["GraphRAG"] = GraphRAG(
  141. embedding_model=embedding_model,
  142. vector_store=vector_store,
  143. llm_client=llm_client,
  144. llm_model=llm_model,
  145. )
  146. # Keyword-based RAG Methods (BM25 / TF-IDF)
  147. methods["BM25RAG"] = BM25RAG(
  148. embedding_model=embedding_model,
  149. vector_store=vector_store,
  150. llm_client=llm_client,
  151. llm_model=llm_model,
  152. )
  153. methods["TFIDFRAG"] = TFIDFRAG(
  154. embedding_model=embedding_model,
  155. vector_store=vector_store,
  156. llm_client=llm_client,
  157. llm_model=llm_model,
  158. )
  159. methods["KeywordRAG_BM25"] = KeywordRAG(
  160. embedding_model=embedding_model,
  161. vector_store=vector_store,
  162. llm_client=llm_client,
  163. llm_model=llm_model,
  164. search_method="bm25",
  165. )
  166. methods["KeywordRAG_TFIDF"] = KeywordRAG(
  167. embedding_model=embedding_model,
  168. vector_store=vector_store,
  169. llm_client=llm_client,
  170. llm_model=llm_model,
  171. search_method="tfidf",
  172. )
  173. return methods
  174. def main():
  175. print("=" * 60)
  176. print("BidiRAG - RAG Methods Benchmark for Bidding Domain")
  177. print("=" * 60)
  178. # Configuration
  179. LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
  180. LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
  181. LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
  182. EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "BAAI/bge-large-zh-v1.5")
  183. # Step 1: Initialize embedding model
  184. print("\n[1/5] Loading embedding model...")
  185. embedding_model = SentenceTransformerEmbedding(
  186. model_name=EMBEDDING_MODEL_NAME,
  187. device="cpu"
  188. )
  189. print(" Embedding dimension: {0}.format(embedding_model.dimension)")
  190. # Step 2: Initialize vector store
  191. print("\n[2/5] Initializing vector store...")
  192. vector_store = FAISSStore(embedding_model=embedding_model)
  193. # Step 3: Process documents
  194. print("\n[3/5] Processing sample bidding documents...")
  195. documents = [
  196. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  197. for doc in SAMPLE_BIDDING_DOCS
  198. ]
  199. print(" Loaded {0} documents.format(len(documents))")
  200. # Step 4: Index documents
  201. print("\n[4/5] Indexing documents...")
  202. for method_name, method in setup_rag_methods(embedding_model, vector_store, None, LLM_MODEL).items():
  203. if method_name == "RAPTORRAG":
  204. method.build_tree(documents)
  205. elif method_name == "GraphRAG":
  206. method.build_graph(documents)
  207. else:
  208. method.index_documents(documents)
  209. print(f" Indexing complete")
  210. # Step 5: Initialize LLM client
  211. print("\n[5/5] Initializing LLM client...")
  212. llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
  213. # Setup RAG methods
  214. print("\nSetting up RAG methods...")
  215. methods = setup_rag_methods(embedding_model, vector_store, llm_client, LLM_MODEL)
  216. # Define test queries
  217. test_queries = [
  218. "XX City Smart Transportation Project budget and deadline?",
  219. "What are the qualification requirements for the hospital equipment procurement?",
  220. "What is the warranty period for the university network project?",
  221. "List all bid bond amounts in the announcements",
  222. "What evaluation methods are used across different projects?",
  223. "XX Road construction project payment terms?",
  224. "Environmental monitoring system equipment list?",
  225. "Which projects require Grade I qualification?",
  226. ]
  227. print("\nRunning benchmark with {0} queries across {1} methods....format(len(test_queries), len(methods))")
  228. print("=" * 60)
  229. # Run benchmark
  230. benchmark = RAGBenchmark()
  231. # Run methods one by one to avoid overwhelming the LLM API
  232. results_summary = {}
  233. for method_name, method in methods.items():
  234. print("\n{0}.format('=' * 60)")
  235. print("Testing: {0}.format(method_name)")
  236. print("{0}.format('=' * 60)")
  237. method_results = []
  238. for i, query in enumerate(test_queries):
  239. print("\n Query {0}/{1}: {2}.format(i + 1, len(test_queries), query)")
  240. try:
  241. result = method.query(query, k=5)
  242. method_results.append(result)
  243. print(" Answer: {0}....format(result.answer[:100])")
  244. print(" Total Latency: {0}s.format(result.latency_total:.3f)")
  245. print(" Retrieval: {0}s | Generation: {1}s.format(result.latency_retrieval:.3f, result.latency_generation:.3f)")
  246. except Exception as e:
  247. print(" ERROR: {0}.format(e)")
  248. results_summary[method_name] = method_results
  249. # Generate comparison report
  250. print("\n\n" + "=" * 60)
  251. print("BENCHMARK RESULTS")
  252. print("=" * 60)
  253. # Create metrics manually from results
  254. from bdirag.benchmark import BenchmarkMetrics, BenchmarkResult
  255. import numpy as np
  256. import datetime
  257. metrics = []
  258. for method_name, results in results_summary.items():
  259. if not results:
  260. continue
  261. total_latencies = [r.latency_total for r in results]
  262. retrieval_latencies = [r.latency_retrieval for r in results]
  263. generation_latencies = [r.latency_generation for r in results]
  264. num_docs = [len(r.retrieved_docs) for r in results]
  265. metric = BenchmarkMetrics(
  266. method_name=method_name,
  267. avg_latency_total=np.mean(total_latencies),
  268. avg_latency_retrieval=np.mean(retrieval_latencies),
  269. avg_latency_generation=np.mean(generation_latencies),
  270. avg_docs_retrieved=np.mean(num_docs),
  271. total_queries=len(results),
  272. latency_std=np.std(total_latencies),
  273. retrieval_std=np.std(retrieval_latencies),
  274. generation_std=np.std(generation_latencies),
  275. min_latency=np.min(total_latencies),
  276. max_latency=np.max(total_latencies),
  277. p50_latency=np.percentile(total_latencies, 50),
  278. p95_latency=np.percentile(total_latencies, 95),
  279. )
  280. metrics.append(metric)
  281. # Print comparison table
  282. print("\nLatency Comparison (sorted by average total latency):")
  283. print("-" * 100)
  284. print("{0} {1} {2} {3} {4} {5} {6} {7}.format('Method':<30, 'Avg Total':>10, 'Avg Retri':>10, 'Avg Gener':>10, 'P50':>8, 'P95':>8, 'Min':>8, 'Max':>8)")
  285. print("-" * 100)
  286. for m in sorted(metrics, key=lambda x: x.avg_latency_total):
  287. print(
  288. "{0} {1} {2} .format(m.method_name:<30, m.avg_latency_total:>10.3f, m.avg_latency_retrieval:>10.3f)"
  289. "{0} {1} {2} .format(m.avg_latency_generation:>10.3f, m.p50_latency:>8.3f, m.p95_latency:>8.3f)"
  290. "{0} {1}.format(m.min_latency:>8.3f, m.max_latency:>8.3f)"
  291. )
  292. print("-" * 100)
  293. print("\nFastest Method: {0}.format(min(metrics, key=lambda x: x.avg_latency_total).method_name)")
  294. print("Most Stable: {0}.format(min(metrics, key=lambda x: x.latency_std).method_name)")
  295. # Save results
  296. output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "output")
  297. os.makedirs(output_dir, exist_ok=True)
  298. benchmark_result = BenchmarkResult(
  299. metrics=metrics,
  300. detailed_results={},
  301. timestamp=datetime.datetime.now().isoformat(),
  302. )
  303. benchmark_result.save(os.path.join(output_dir, "benchmark_results.json"))
  304. benchmark.generate_report(benchmark_result, os.path.join(output_dir, "benchmark_report.md"))
  305. try:
  306. benchmark.plot_comparison(
  307. metrics,
  308. save_path=os.path.join(output_dir, "benchmark_comparison.png"),
  309. show=False,
  310. )
  311. except Exception as e:
  312. print("\nNote: Could not generate plot: {0}.format(e)")
  313. print("\nResults saved to {0}.format(output_dir)")
  314. print("\nBenchmark complete!")
  315. if __name__ == "__main__":
  316. main()