# benchmark_retrieval_speed.py (5.8 KB)
  1. """
  2. Speed-focused benchmark script - compares retrieval speed without LLM generation
  3. Use this to quickly compare the performance of different retrieval methods
  4. """
  5. import sys
  6. import os
  7. import time
  8. import numpy as np
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. from rich.console import Console
  11. from rich.table import Table
  12. from bdirag.document_processor import Document
  13. from bdirag.embedding_models import SentenceTransformerEmbedding
  14. from bdirag.vector_stores import FAISSStore
  15. from bdirag.rag_methods import (
  16. NaiveRAG,
  17. HybridSearchRAG,
  18. MultiQueryRAG,
  19. HyDERAG,
  20. SelfRAG,
  21. StepBackRAG,
  22. BidFieldExtractionRAG,
  23. TableAwareRAG,
  24. EnsembleRAG,
  25. GraphRAG,
  26. BM25RAG,
  27. TFIDFRAG,
  28. KeywordRAG,
  29. )
  30. from examples.sample_data import SAMPLE_BIDDING_DOCS
  31. class RetrievalBenchmark:
  32. def __init__(self):
  33. self.results = {}
  34. def run(self, methods, queries, iterations=3):
  35. for method_name, method in methods.items():
  36. print("\nBenchmarking {0}....format(method_name)")
  37. latencies = []
  38. for query in queries:
  39. query_latencies = []
  40. for _ in range(iterations):
  41. start = time.time()
  42. docs = method.retrieve(query, k=5)
  43. elapsed = time.time() - start
  44. query_latencies.append(elapsed)
  45. avg_latency = np.mean(query_latencies)
  46. latencies.append(avg_latency)
  47. print(" Query: {0}... -> {1}s.format(query[:50], avg_latency:.3f)")
  48. self.results[method_name] = {
  49. "latencies": latencies,
  50. "avg": np.mean(latencies),
  51. "std": np.std(latencies),
  52. "min": np.min(latencies),
  53. "max": np.max(latencies),
  54. "p50": np.percentile(latencies, 50),
  55. "p95": np.percentile(latencies, 95),
  56. }
  57. self.print_results()
  58. def print_results(self):
  59. console = Console()
  60. table = Table(title="Retrieval Speed Comparison")
  61. table.add_column("Method", style="cyan")
  62. table.add_column("Avg (s)", justify="right", style="green")
  63. table.add_column("Std (s)", justify="right", style="green")
  64. table.add_column("Min (s)", justify="right", style="yellow")
  65. table.add_column("Max (s)", justify="right", style="yellow")
  66. table.add_column("P50 (s)", justify="right", style="magenta")
  67. table.add_column("P95 (s)", justify="right", style="magenta")
  68. table.add_column("QPS", justify="right", style="blue")
  69. for name in sorted(self.results.keys(), key=lambda x: self.results[x]["avg"]):
  70. r = self.results[name]
  71. table.add_row(
  72. name,
  73. "{0}.format(r['avg']:.4f)",
  74. "{0}.format(r['std']:.4f)",
  75. "{0}.format(r['min']:.4f)",
  76. "{0}.format(r['max']:.4f)",
  77. "{0}.format(r['p50']:.4f)",
  78. "{0}.format(r['p95']:.4f)",
  79. "{0}.format(1/r['avg']:.1f)",
  80. )
  81. console.print(table)
  82. def main():
  83. print("=" * 60)
  84. print("BidiRAG - Retrieval Speed Benchmark")
  85. print("=" * 60)
  86. # Load embedding model
  87. print("\nLoading embedding model...")
  88. embedding_model = SentenceTransformerEmbedding(
  89. model_name="BAAI/bge-large-zh-v1.5",
  90. device="cpu"
  91. )
  92. # Create vector store
  93. print("Creating vector store...")
  94. vector_store = FAISSStore(embedding_model=embedding_model)
  95. # Prepare documents
  96. documents = [
  97. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  98. for doc in SAMPLE_BIDDING_DOCS
  99. ]
  100. # Initialize methods (no LLM needed for pure retrieval)
  101. print("Initializing retrieval methods...")
  102. methods = {
  103. "BM25RAG": BM25RAG(
  104. embedding_model=embedding_model,
  105. vector_store=vector_store,
  106. ),
  107. "TFIDFRAG": TFIDFRAG(
  108. embedding_model=embedding_model,
  109. vector_store=vector_store,
  110. ),
  111. "KeywordRAG_BM25": KeywordRAG(
  112. embedding_model=embedding_model,
  113. vector_store=vector_store,
  114. search_method="bm25",
  115. ),
  116. "KeywordRAG_TFIDF": KeywordRAG(
  117. embedding_model=embedding_model,
  118. vector_store=vector_store,
  119. search_method="tfidf",
  120. ),
  121. "NaiveRAG": NaiveRAG(
  122. embedding_model=embedding_model,
  123. vector_store=vector_store,
  124. ),
  125. "HybridSearchRAG": HybridSearchRAG(
  126. embedding_model=embedding_model,
  127. vector_store=vector_store,
  128. ),
  129. "BidFieldExtractionRAG": BidFieldExtractionRAG(
  130. embedding_model=embedding_model,
  131. vector_store=vector_store,
  132. ),
  133. "TableAwareRAG": TableAwareRAG(
  134. embedding_model=embedding_model,
  135. vector_store=vector_store,
  136. ),
  137. "EnsembleRAG": EnsembleRAG(
  138. embedding_model=embedding_model,
  139. vector_store=vector_store,
  140. ),
  141. }
  142. # Index documents
  143. print("Indexing documents...")
  144. for method in methods.values():
  145. method.index_documents(documents)
  146. # Test queries
  147. test_queries = [
  148. "What is the project budget?",
  149. "What are the qualification requirements?",
  150. "When is the bid deadline?",
  151. "What is the warranty period?",
  152. "What are the payment terms?",
  153. "What is the evaluation method?",
  154. "What equipment is needed?",
  155. "What is the delivery time?",
  156. ]
  157. # Run benchmark
  158. print("\nRunning benchmark with {0} queries (3 iterations each)....format(len(test_queries))")
  159. benchmark = RetrievalBenchmark()
  160. benchmark.run(methods, test_queries, iterations=3)
  161. print("\nBenchmark complete!")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()