| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- import time
- import json
- import os
- from typing import List, Dict, Any, Optional
- from dataclasses import dataclass, asdict
- import numpy as np
- from loguru import logger
- from .rag_methods import BaseRAG, RAGResult
@dataclass
class BenchmarkMetrics:
    """Aggregated latency statistics for a single RAG method.

    Latencies are in seconds, matching how the benchmark runner reports and
    prints them. Instances are typically built from the per-query results of
    one method in a benchmark run.
    """

    # Name of the RAG method these statistics describe.
    method_name: str
    # Mean latencies across the aggregated queries.
    avg_latency_total: float
    avg_latency_retrieval: float
    avg_latency_generation: float
    # Mean number of documents returned per query.
    avg_docs_retrieved: float
    # Number of query results that went into the aggregate.
    total_queries: int
    # Spread of total / retrieval / generation latency.
    latency_std: float
    retrieval_std: float
    generation_std: float
    # Distribution summary of the total latency.
    min_latency: float
    max_latency: float
    p50_latency: float
    p95_latency: float

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as an entry of a plain dict, in declaration order."""
        as_mapping = asdict(self)
        return as_mapping
@dataclass
class BenchmarkResult:
    """Outcome of one benchmark run: summary metrics plus the raw results."""

    # One metrics entry per method that produced results.
    metrics: List[BenchmarkMetrics]
    # Method name -> raw per-query results for that method.
    detailed_results: Dict[str, List[RAGResult]]
    # When the run was recorded, as a string.
    timestamp: str

    def to_dict(self):
        """Return a JSON-serialisable dict with the metrics and the timestamp.

        ``detailed_results`` is not included in the output.
        """
        serialized_metrics = [entry.to_dict() for entry in self.metrics]
        return {"metrics": serialized_metrics, "timestamp": self.timestamp}

    def save(self, path):
        """Write :meth:`to_dict` to ``path`` as indented UTF-8 JSON.

        The parent directory is created if missing; a bare filename is written
        into the current working directory.
        """
        parent_dir = os.path.dirname(path) or "."
        os.makedirs(parent_dir, exist_ok=True)
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(self.to_dict(), handle, ensure_ascii=False, indent=2)
        logger.info("Benchmark results saved to {}".format(path))
class RAGBenchmark:
    """Harness that times several RAG methods over a shared list of queries.

    After :meth:`run_benchmark`, ``self.results`` maps each method's ``name`` to
    the list of results for the queries that completed without raising.
    """

    def __init__(self):
        # Method name -> list of per-query results from the most recent run.
        self.results = {}

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` (cwd for bare names)."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    @staticmethod
    def _style_method_axis(ax, x, methods, title):
        """Apply the shared labels, ticks, legend and grid of a per-method bar chart."""
        ax.set_xlabel("Method")
        ax.set_ylabel("Time (seconds)")
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(methods, rotation=45, ha="right")
        ax.legend()
        ax.grid(True, alpha=0.3)

    def run_single_query(self, rag_method, query, k=10):
        """Run ``query`` through ``rag_method`` and return its result unchanged.

        Args:
            rag_method: Object exposing ``query(query, k)``.
            query: The query text.
            k: Retrieval depth, passed positionally to ``rag_method.query``.

        Returns:
            Whatever ``rag_method.query`` returns.
        """
        return rag_method.query(query, k)

    def run_benchmark(self, rag_methods, queries, k=10, verbose=True):
        """Run every query through every method and aggregate latency metrics.

        A query that raises is logged and skipped; it does not abort the run.
        Every method gets an entry in ``self.results`` (possibly empty), but
        methods with no successful queries produce no metrics.

        Args:
            rag_methods: Iterable of objects with a ``name`` attribute and a
                ``query(query, k)`` method.
            queries: Sequence of query strings.
            k: Number of documents to retrieve per query.
            verbose: Log per-query progress and print a summary table.

        Returns:
            A ``BenchmarkResult`` with the metrics, the per-method results and
            an ISO-8601 timestamp taken when the run finished.
        """
        import datetime

        self.results = {}
        num_queries = len(queries)
        for method in rag_methods:
            method_results = []
            if verbose:
                logger.info("Running benchmark for: {}".format(method.name))
            for idx, query in enumerate(queries, start=1):
                if verbose:
                    logger.info(" Query {}/{}: {}...".format(idx, num_queries, query[:50]))
                try:
                    result = self.run_single_query(method, query, k)
                    method_results.append(result)
                    if verbose:
                        logger.info(" Answer: {}...".format(result.answer[:80]))
                        logger.info(" Latency: {:.3f}s".format(result.latency_total))
                except Exception as e:
                    # Best-effort: log the failure and keep benchmarking the
                    # remaining queries and methods.
                    logger.error(" Error on query {} for {}: {}".format(idx, method.name, e))
            self.results[method.name] = method_results

        metrics = self._compute_metrics()
        benchmark_result = BenchmarkResult(
            metrics=metrics,
            detailed_results=self.results,
            timestamp=datetime.datetime.now().isoformat(),
        )
        if verbose:
            self._print_summary(metrics)
        return benchmark_result

    def _compute_metrics(self):
        """Aggregate ``self.results`` into one ``BenchmarkMetrics`` per method.

        Methods with an empty result list are skipped. Standard deviations are
        population values (``np.std`` with its default ``ddof=0``). Every
        statistic is converted to a plain Python ``float`` so the metrics carry
        no numpy scalar types.
        """
        metrics = []
        for method_name, results in self.results.items():
            if not results:
                continue
            total = np.array([r.latency_total for r in results], dtype=float)
            retrieval = np.array([r.latency_retrieval for r in results], dtype=float)
            generation = np.array([r.latency_generation for r in results], dtype=float)
            num_docs = np.array([len(r.retrieved_docs) for r in results], dtype=float)
            metrics.append(BenchmarkMetrics(
                method_name=method_name,
                avg_latency_total=float(np.mean(total)),
                avg_latency_retrieval=float(np.mean(retrieval)),
                avg_latency_generation=float(np.mean(generation)),
                avg_docs_retrieved=float(np.mean(num_docs)),
                total_queries=len(results),
                latency_std=float(np.std(total)),
                retrieval_std=float(np.std(retrieval)),
                generation_std=float(np.std(generation)),
                min_latency=float(np.min(total)),
                max_latency=float(np.max(total)),
                p50_latency=float(np.percentile(total, 50)),
                p95_latency=float(np.percentile(total, 95)),
            ))
        return metrics

    def _print_summary(self, metrics):
        """Print ``metrics`` as a rich table, fastest average total latency first."""
        from rich.console import Console
        from rich.table import Table

        table = Table(title="RAG Methods Benchmark Results")
        table.add_column("Method", style="cyan")
        for header, style in (
            ("Avg Total(s)", "green"),
            ("Avg Retrieval(s)", "green"),
            ("Avg Generation(s)", "green"),
            ("P50(s)", "yellow"),
            ("P95(s)", "yellow"),
            ("Min(s)", "magenta"),
            ("Max(s)", "magenta"),
            ("Avg Docs", "blue"),
        ):
            table.add_column(header, justify="right", style=style)
        for m in sorted(metrics, key=lambda x: x.avg_latency_total):
            table.add_row(
                m.method_name,
                "{:.3f}".format(m.avg_latency_total),
                "{:.3f}".format(m.avg_latency_retrieval),
                "{:.3f}".format(m.avg_latency_generation),
                "{:.3f}".format(m.p50_latency),
                "{:.3f}".format(m.p95_latency),
                "{:.3f}".format(m.min_latency),
                "{:.3f}".format(m.max_latency),
                "{:.1f}".format(m.avg_docs_retrieved),
            )
        Console().print(table)

    def plot_comparison(self, metrics, save_path=None, show=True):
        """Draw a 2x2 dashboard comparing ``metrics`` across methods.

        Panels show:
            * Average retrieval, generation and total latency.
            * P50 and P95 total latency.
            * Average number of retrieved documents.
            * Throughput, computed as 1 / average total latency.

        Args:
            metrics: Sequence of ``BenchmarkMetrics``.
            save_path: If given, the figure is saved there at 300 dpi; missing
                parent directories are created.
            show: Call ``plt.show()`` before the figure is closed.
        """
        import matplotlib
        import matplotlib.pyplot as plt

        # Presumably so CJK text renders (SimHei is a Chinese font). Note this
        # mutates matplotlib's global rcParams for the rest of the process.
        matplotlib.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
        matplotlib.rcParams["axes.unicode_minus"] = False

        methods = [m.method_name for m in metrics]
        avg_total = [m.avg_latency_total for m in metrics]
        avg_retrieval = [m.avg_latency_retrieval for m in metrics]
        avg_generation = [m.avg_latency_generation for m in metrics]
        p50 = [m.p50_latency for m in metrics]
        p95 = [m.p95_latency for m in metrics]
        avg_docs = [m.avg_docs_retrieved for m in metrics]
        # A non-positive mean latency has no meaningful rate; NaN draws no bar
        # instead of raising ZeroDivisionError or plotting an infinite bar.
        speeds = [
            1.0 / t if t > 0 else float("nan")
            for t in avg_total
        ]

        x = np.arange(len(methods))
        width = 0.25

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle("RAG Methods Performance Comparison", fontsize=16, fontweight="bold")

        # Top-left: three bars per method, the middle one centred on the tick.
        ax1 = axes[0, 0]
        ax1.bar(x - width, avg_retrieval, width, label="Retrieval", color="#4CAF50", alpha=0.8)
        ax1.bar(x, avg_generation, width, label="Generation", color="#2196F3", alpha=0.8)
        ax1.bar(x + width, avg_total, width, label="Total", color="#FF9800", alpha=0.8)
        self._style_method_axis(ax1, x, methods, "Average Latency Comparison")

        # Top-right: two bars offset by half a width each, so they sit side by
        # side around the tick without overlapping.
        ax2 = axes[0, 1]
        ax2.bar(x - width / 2, p50, width, label="P50", color="#9C27B0", alpha=0.8)
        ax2.bar(x + width / 2, p95, width, label="P95", color="#E91E63", alpha=0.8)
        self._style_method_axis(ax2, x, methods, "Percentile Latency Comparison")

        ax3 = axes[1, 0]
        ax3.barh(methods, avg_docs, color="#00BCD4", alpha=0.8)
        ax3.set_xlabel("Average Number of Documents")
        ax3.set_title("Average Retrieved Documents")
        ax3.grid(True, alpha=0.3, axis="x")

        ax4 = axes[1, 1]
        ax4.barh(methods, speeds, color="#8BC34A", alpha=0.8)
        ax4.set_xlabel("Queries per Second")
        ax4.set_title("Throughput Comparison")
        ax4.grid(True, alpha=0.3, axis="x")

        fig.tight_layout()
        if save_path:
            self._ensure_parent_dir(save_path)
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info("Plot saved to {}".format(save_path))
        if show:
            plt.show()
        # Close this specific figure rather than whatever is "current".
        plt.close(fig)

    def generate_report(self, benchmark_result, save_path=None):
        """Render ``benchmark_result`` as a Markdown report and optionally save it.

        The summary table lists methods fastest-first by average total latency.
        It is followed by highlights: the fastest method, the method with the
        most documents retrieved, and the method with the lowest latency
        standard deviation. With no metrics, the highlights section says so
        instead of raising.

        Args:
            benchmark_result: Object with ``metrics`` (sequence of
                ``BenchmarkMetrics``) and ``timestamp`` attributes.
            save_path: Optional output file; missing parent directories are
                created.

        Returns:
            The report as a Markdown string.
        """
        metrics = list(benchmark_result.metrics)
        parts = [
            "# RAG Benchmark Report\n\n",
            "**Timestamp**: {}\n\n".format(benchmark_result.timestamp),
            "## Summary\n\n",
            "| Method | Avg Total(s) | Avg Retrieval(s) | Avg Generation(s) | P50(s) | P95(s) |\n",
            "|--------|-------------|-----------------|-------------------|--------|--------|\n",
        ]
        for m in sorted(metrics, key=lambda x: x.avg_latency_total):
            parts.append("| {} | {:.3f} | {:.3f} | {:.3f} | {:.3f} | {:.3f} |\n".format(
                m.method_name, m.avg_latency_total, m.avg_latency_retrieval,
                m.avg_latency_generation, m.p50_latency, m.p95_latency))

        parts.append("\n## Detailed Analysis\n\n")
        if metrics:
            fastest = min(metrics, key=lambda x: x.avg_latency_total)
            most_docs = max(metrics, key=lambda x: x.avg_docs_retrieved)
            most_stable = min(metrics, key=lambda x: x.latency_std)
            parts.append("- **Fastest Method**: {} ({:.3f}s average)\n".format(
                fastest.method_name, fastest.avg_latency_total))
            parts.append("- **Most Documents Retrieved**: {} ({:.1f} average)\n".format(
                most_docs.method_name, most_docs.avg_docs_retrieved))
            parts.append("- **Most Stable**: {} (std={:.3f})\n".format(
                most_stable.method_name, most_stable.latency_std))
        else:
            # min()/max() would raise ValueError on an empty sequence.
            parts.append("- No successful queries were recorded.\n")

        report = "".join(parts)
        if save_path:
            self._ensure_parent_dir(save_path)
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(report)
            logger.info("Report saved to {}".format(save_path))
        return report
|