benchmark.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import time
  2. import json
  3. import os
  4. from typing import List, Dict, Any, Optional
  5. from dataclasses import dataclass, asdict
  6. import numpy as np
  7. from loguru import logger
  8. from .rag_methods import BaseRAG, RAGResult
  9. @dataclass
  10. class BenchmarkMetrics:
  11. method_name: str
  12. avg_latency_total: float
  13. avg_latency_retrieval: float
  14. avg_latency_generation: float
  15. avg_docs_retrieved: float
  16. total_queries: int
  17. latency_std: float
  18. retrieval_std: float
  19. generation_std: float
  20. min_latency: float
  21. max_latency: float
  22. p50_latency: float
  23. p95_latency: float
  24. def to_dict(self):
  25. return asdict(self)
  26. @dataclass
  27. class BenchmarkResult:
  28. metrics: List[BenchmarkMetrics]
  29. detailed_results: Dict[str, List[RAGResult]]
  30. timestamp: str
  31. def to_dict(self):
  32. return {
  33. "metrics": [m.to_dict() for m in self.metrics],
  34. "timestamp": self.timestamp,
  35. }
  36. def save(self, path):
  37. os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
  38. with open(path, "w", encoding="utf-8") as f:
  39. json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
  40. logger.info("Benchmark results saved to {}".format(path))
  41. class RAGBenchmark:
  42. def __init__(self):
  43. self.results = {}
  44. def run_single_query(self, rag_method, query, k=10):
  45. result = rag_method.query(query, k)
  46. return result
  47. def run_benchmark(self, rag_methods, queries, k=10, verbose=True):
  48. self.results = {}
  49. for method in rag_methods:
  50. method_results = []
  51. if verbose:
  52. logger.info("Running benchmark for: {}".format(method.name))
  53. for i, query in enumerate(queries):
  54. if verbose:
  55. logger.info(" Query {}/{}: {}...".format(i + 1, len(queries), query[:50]))
  56. try:
  57. result = self.run_single_query(method, query, k)
  58. method_results.append(result)
  59. if verbose:
  60. logger.info(" Answer: {}...".format(result.answer[:80]))
  61. logger.info(" Latency: {:.3f}s".format(result.latency_total))
  62. except Exception as e:
  63. logger.error(" Error on query {} for {}: {}".format(i + 1, method.name, e))
  64. self.results[method.name] = method_results
  65. metrics = self._compute_metrics()
  66. import datetime
  67. benchmark_result = BenchmarkResult(
  68. metrics=metrics,
  69. detailed_results=self.results,
  70. timestamp=datetime.datetime.now().isoformat(),
  71. )
  72. if verbose:
  73. self._print_summary(metrics)
  74. return benchmark_result
  75. def _compute_metrics(self):
  76. metrics = []
  77. for method_name, results in self.results.items():
  78. if not results:
  79. continue
  80. total_latencies = [r.latency_total for r in results]
  81. retrieval_latencies = [r.latency_retrieval for r in results]
  82. generation_latencies = [r.latency_generation for r in results]
  83. num_docs = [len(r.retrieved_docs) for r in results]
  84. metric = BenchmarkMetrics(
  85. method_name=method_name,
  86. avg_latency_total=np.mean(total_latencies),
  87. avg_latency_retrieval=np.mean(retrieval_latencies),
  88. avg_latency_generation=np.mean(generation_latencies),
  89. avg_docs_retrieved=np.mean(num_docs),
  90. total_queries=len(results),
  91. latency_std=np.std(total_latencies),
  92. retrieval_std=np.std(retrieval_latencies),
  93. generation_std=np.std(generation_latencies),
  94. min_latency=np.min(total_latencies),
  95. max_latency=np.max(total_latencies),
  96. p50_latency=np.percentile(total_latencies, 50),
  97. p95_latency=np.percentile(total_latencies, 95),
  98. )
  99. metrics.append(metric)
  100. return metrics
  101. def _print_summary(self, metrics):
  102. from rich.console import Console
  103. from rich.table import Table
  104. console = Console()
  105. table = Table(title="RAG Methods Benchmark Results")
  106. table.add_column("Method", style="cyan")
  107. table.add_column("Avg Total(s)", justify="right", style="green")
  108. table.add_column("Avg Retrieval(s)", justify="right", style="green")
  109. table.add_column("Avg Generation(s)", justify="right", style="green")
  110. table.add_column("P50(s)", justify="right", style="yellow")
  111. table.add_column("P95(s)", justify="right", style="yellow")
  112. table.add_column("Min(s)", justify="right", style="magenta")
  113. table.add_column("Max(s)", justify="right", style="magenta")
  114. table.add_column("Avg Docs", justify="right", style="blue")
  115. for m in sorted(metrics, key=lambda x: x.avg_latency_total):
  116. table.add_row(
  117. m.method_name,
  118. "{:.3f}".format(m.avg_latency_total),
  119. "{:.3f}".format(m.avg_latency_retrieval),
  120. "{:.3f}".format(m.avg_latency_generation),
  121. "{:.3f}".format(m.p50_latency),
  122. "{:.3f}".format(m.p95_latency),
  123. "{:.3f}".format(m.min_latency),
  124. "{:.3f}".format(m.max_latency),
  125. "{:.1f}".format(m.avg_docs_retrieved),
  126. )
  127. console.print(table)
  128. def plot_comparison(self, metrics, save_path=None, show=True):
  129. import matplotlib.pyplot as plt
  130. import matplotlib
  131. matplotlib.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
  132. matplotlib.rcParams["axes.unicode_minus"] = False
  133. methods = [m.method_name for m in metrics]
  134. avg_total = [m.avg_latency_total for m in metrics]
  135. avg_retrieval = [m.avg_latency_retrieval for m in metrics]
  136. avg_generation = [m.avg_latency_generation for m in metrics]
  137. p50 = [m.p50_latency for m in metrics]
  138. p95 = [m.p95_latency for m in metrics]
  139. fig, axes = plt.subplots(2, 2, figsize=(14, 10))
  140. fig.suptitle("RAG Methods Performance Comparison", fontsize=16, fontweight="bold")
  141. colors = plt.cm.Set3(np.linspace(0, 1, len(methods)))
  142. x = np.arange(len(methods))
  143. width = 0.25
  144. ax1 = axes[0, 0]
  145. bars1 = ax1.bar(x - width, avg_retrieval, width, label="Retrieval", color="#4CAF50", alpha=0.8)
  146. bars2 = ax1.bar(x, avg_generation, width, label="Generation", color="#2196F3", alpha=0.8)
  147. bars3 = ax1.bar(x + width, avg_total, width, label="Total", color="#FF9800", alpha=0.8)
  148. ax1.set_xlabel("Method")
  149. ax1.set_ylabel("Time (seconds)")
  150. ax1.set_title("Average Latency Comparison")
  151. ax1.set_xticks(x)
  152. ax1.set_xticklabels(methods, rotation=45, ha="right")
  153. ax1.legend()
  154. ax1.grid(True, alpha=0.3)
  155. ax2 = axes[0, 1]
  156. ax2.bar(x, p50, width, label="P50", color="#9C27B0", alpha=0.8)
  157. ax2.bar(x + width * 0.5, p95, width, label="P95", color="#E91E63", alpha=0.8)
  158. ax2.set_xlabel("Method")
  159. ax2.set_ylabel("Time (seconds)")
  160. ax2.set_title("Percentile Latency Comparison")
  161. ax2.set_xticks(x)
  162. ax2.set_xticklabels(methods, rotation=45, ha="right")
  163. ax2.legend()
  164. ax2.grid(True, alpha=0.3)
  165. ax3 = axes[1, 0]
  166. avg_docs = [m.avg_docs_retrieved for m in metrics]
  167. ax3.barh(methods, avg_docs, color="#00BCD4", alpha=0.8)
  168. ax3.set_xlabel("Average Number of Documents")
  169. ax3.set_title("Average Retrieved Documents")
  170. ax3.grid(True, alpha=0.3, axis="x")
  171. ax4 = axes[1, 1]
  172. speeds = [1.0 / m.avg_latency_total for m in metrics]
  173. ax4.barh(methods, speeds, color="#8BC34A", alpha=0.8)
  174. ax4.set_xlabel("Queries per Second")
  175. ax4.set_title("Throughput Comparison")
  176. ax4.grid(True, alpha=0.3, axis="x")
  177. plt.tight_layout()
  178. if save_path:
  179. os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
  180. plt.savefig(save_path, dpi=300, bbox_inches="tight")
  181. logger.info("Plot saved to {}".format(save_path))
  182. if show:
  183. plt.show()
  184. plt.close()
  185. def generate_report(self, benchmark_result, save_path=None):
  186. report = "# RAG Benchmark Report\n\n"
  187. report += "**Timestamp**: {}\n\n".format(benchmark_result.timestamp)
  188. report += "## Summary\n\n"
  189. report += "| Method | Avg Total(s) | Avg Retrieval(s) | Avg Generation(s) | P50(s) | P95(s) |\n"
  190. report += "|--------|-------------|-----------------|-------------------|--------|--------|\n"
  191. for m in sorted(benchmark_result.metrics, key=lambda x: x.avg_latency_total):
  192. report += "| {} | {:.3f} | {:.3f} | {:.3f} | {:.3f} | {:.3f} |\n".format(
  193. m.method_name, m.avg_latency_total, m.avg_latency_retrieval,
  194. m.avg_latency_generation, m.p50_latency, m.p95_latency)
  195. report += "\n## Detailed Analysis\n\n"
  196. fastest = min(benchmark_result.metrics, key=lambda x: x.avg_latency_total)
  197. report += "- **Fastest Method**: {} ({:.3f}s average)\n".format(fastest.method_name, fastest.avg_latency_total)
  198. most_docs = max(benchmark_result.metrics, key=lambda x: x.avg_docs_retrieved)
  199. report += "- **Most Documents Retrieved**: {} ({:.1f} average)\n".format(most_docs.method_name, most_docs.avg_docs_retrieved)
  200. most_stable = min(benchmark_result.metrics, key=lambda x: x.latency_std)
  201. report += "- **Most Stable**: {} (std={:.3f})\n".format(most_stable.method_name, most_stable.latency_std)
  202. if save_path:
  203. os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
  204. with open(save_path, "w", encoding="utf-8") as f:
  205. f.write(report)
  206. logger.info("Report saved to {}".format(save_path))
  207. return report