"""Benchmark harness that runs queries through RAG methods and summarizes latency."""

import datetime
import json
import os
import time
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional

import numpy as np
from loguru import logger

from .rag_methods import BaseRAG, RAGResult


@dataclass
class BenchmarkMetrics:
    """Aggregated latency and retrieval statistics for one RAG method.

    The ``*_latency*`` / ``*_std`` fields are computed from the
    ``latency_total``, ``latency_retrieval`` and ``latency_generation``
    attributes of the collected results; their unit is whatever the RAG
    methods report (the summary table and plots label it as seconds).
    """

    method_name: str
    avg_latency_total: float
    avg_latency_retrieval: float
    avg_latency_generation: float
    avg_docs_retrieved: float
    total_queries: int
    latency_std: float
    retrieval_std: float
    generation_std: float
    min_latency: float
    max_latency: float
    p50_latency: float
    p95_latency: float

    def to_dict(self) -> Dict[str, Any]:
        """Return the metrics as a plain ``{field_name: value}`` dictionary."""
        return asdict(self)


@dataclass
class BenchmarkResult:
    """Outcome of one benchmark run: per-method metrics plus the raw results."""

    metrics: List[BenchmarkMetrics]
    detailed_results: Dict[str, List[RAGResult]]
    timestamp: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the serializable summary of this run.

        Only ``metrics`` and ``timestamp`` are included; ``detailed_results``
        is intentionally left out of the serialized form.
        """
        return {
            "metrics": [m.to_dict() for m in self.metrics],
            "timestamp": self.timestamp,
        }

    def save(self, path: str) -> None:
        """Write :meth:`to_dict` as indented UTF-8 JSON to ``path``.

        Missing parent directories are created first.
        """
        # dirname is "" for a bare filename; fall back to the current directory.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
        logger.info("Benchmark results saved to {}".format(path))


class RAGBenchmark:
    """Run a set of queries through several RAG methods and compare them."""

    def __init__(self) -> None:
        # Maps method name -> results of the queries that completed successfully.
        self.results: Dict[str, List[RAGResult]] = {}

    def run_single_query(self, rag_method: BaseRAG, query: str, k: int = 10) -> RAGResult:
        """Run ``query`` through ``rag_method`` and return its result.

        Args:
            rag_method: Object exposing ``query(query, k)``.
            query: The query text.
            k: Forwarded unchanged as the second argument of ``rag_method.query``.
        """
        return rag_method.query(query, k)

    def run_benchmark(
        self,
        rag_methods: List[BaseRAG],
        queries: List[str],
        k: int = 10,
        verbose: bool = True,
    ) -> BenchmarkResult:
        """Run every query against every method and aggregate the outcome.

        Any previous contents of ``self.results`` are discarded. A query that
        raises is logged and skipped, so one failure does not abort the run.

        Args:
            rag_methods: Methods to benchmark; each needs ``name`` and
                ``query(query, k)``.
            queries: Query strings to run through each method.
            k: Forwarded to each method's ``query`` call.
            verbose: When true, log per-query progress and print a summary table.

        Returns:
            A :class:`BenchmarkResult` holding the computed metrics, the raw
            per-method results and an ISO-8601 local timestamp.
        """
        self.results = {}
        total = len(queries)

        for method in rag_methods:
            method_results = []
            if verbose:
                logger.info("Running benchmark for: {}".format(method.name))

            for i, query in enumerate(queries):
                if verbose:
                    logger.info(" Query {}/{}: {}...".format(i + 1, total, query[:50]))
                # Best-effort: the whole step (including progress logging) is
                # guarded so a single bad query/result never aborts the run.
                try:
                    result = self.run_single_query(method, query, k)
                    method_results.append(result)
                    if verbose:
                        logger.info(" Answer: {}...".format(result.answer[:80]))
                        logger.info(" Latency: {:.3f}s".format(result.latency_total))
                except Exception as e:
                    logger.error(
                        " Error on query {} for {}: {}".format(i + 1, method.name, e)
                    )

            self.results[method.name] = method_results

        metrics = self._compute_metrics()
        benchmark_result = BenchmarkResult(
            metrics=metrics,
            detailed_results=self.results,
            timestamp=datetime.datetime.now().isoformat(),
        )

        if verbose:
            self._print_summary(metrics)

        return benchmark_result

    def _compute_metrics(self) -> List[BenchmarkMetrics]:
        """Aggregate ``self.results`` into one :class:`BenchmarkMetrics` per method.

        Methods without any successful result are skipped (with a warning),
        because mean/percentile are undefined on an empty sample.
        """
        metrics = []
        for method_name, results in self.results.items():
            if not results:
                logger.warning(
                    "No successful results for {}; skipping metrics".format(method_name)
                )
                continue

            total_latencies = [r.latency_total for r in results]
            retrieval_latencies = [r.latency_retrieval for r in results]
            generation_latencies = [r.latency_generation for r in results]
            num_docs = [len(r.retrieved_docs) for r in results]

            # Cast numpy scalars to builtin floats so the dataclass matches its
            # annotations and always serializes cleanly with ``json``.
            # np.std uses its default (population, ddof=0) definition.
            metrics.append(
                BenchmarkMetrics(
                    method_name=method_name,
                    avg_latency_total=float(np.mean(total_latencies)),
                    avg_latency_retrieval=float(np.mean(retrieval_latencies)),
                    avg_latency_generation=float(np.mean(generation_latencies)),
                    avg_docs_retrieved=float(np.mean(num_docs)),
                    total_queries=len(results),
                    latency_std=float(np.std(total_latencies)),
                    retrieval_std=float(np.std(retrieval_latencies)),
                    generation_std=float(np.std(generation_latencies)),
                    min_latency=float(np.min(total_latencies)),
                    max_latency=float(np.max(total_latencies)),
                    p50_latency=float(np.percentile(total_latencies, 50)),
                    p95_latency=float(np.percentile(total_latencies, 95)),
                )
            )
        return metrics

    def _print_summary(self, metrics: List[BenchmarkMetrics]) -> None:
        """Print a rich table of ``metrics``, fastest average total latency first."""
        # Imported lazily so ``rich`` is only needed when a summary is printed.
        from rich.console import Console
        from rich.table import Table

        console = Console()
        table = Table(title="RAG Methods Benchmark Results")
        table.add_column("Method", style="cyan")
        table.add_column("Avg Total(s)", justify="right", style="green")
        table.add_column("Avg Retrieval(s)", justify="right", style="green")
        table.add_column("Avg Generation(s)", justify="right", style="green")
        table.add_column("P50(s)", justify="right", style="yellow")
        table.add_column("P95(s)", justify="right", style="yellow")
        table.add_column("Min(s)", justify="right", style="magenta")
        table.add_column("Max(s)", justify="right", style="magenta")
        table.add_column("Avg Docs", justify="right", style="blue")

        for m in sorted(metrics, key=lambda x: x.avg_latency_total):
            table.add_row(
                m.method_name,
                "{:.3f}".format(m.avg_latency_total),
                "{:.3f}".format(m.avg_latency_retrieval),
                "{:.3f}".format(m.avg_latency_generation),
                "{:.3f}".format(m.p50_latency),
                "{:.3f}".format(m.p95_latency),
                "{:.3f}".format(m.min_latency),
                "{:.3f}".format(m.max_latency),
                "{:.1f}".format(m.avg_docs_retrieved),
            )

        console.print(table)

    def plot_comparison(
        self,
        metrics: List[BenchmarkMetrics],
        save_path: Optional[str] = None,
        show: bool = True,
    ) -> None:
        """Draw a 2x2 figure comparing the methods in ``metrics``.

        Panels: average latencies, P50/P95 latencies, average retrieved
        documents, and throughput (``1 / avg_latency_total``).

        Args:
            metrics: Per-method metrics to plot. Nothing is drawn when empty.
            save_path: If given, the figure is saved there (parent directories
                are created).
            show: If true, display the figure with ``plt.show()``.
        """
        if not metrics:
            logger.warning("No metrics to plot")
            return

        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib
        import matplotlib.pyplot as plt

        # NOTE: this changes matplotlib's global rcParams. The fonts listed are
        # presumably chosen so CJK method names render — confirm before changing.
        matplotlib.rcParams["font.sans-serif"] = ["SimHei", "Arial Unicode MS"]
        matplotlib.rcParams["axes.unicode_minus"] = False

        methods = [m.method_name for m in metrics]
        avg_total = [m.avg_latency_total for m in metrics]
        avg_retrieval = [m.avg_latency_retrieval for m in metrics]
        avg_generation = [m.avg_latency_generation for m in metrics]
        p50 = [m.p50_latency for m in metrics]
        p95 = [m.p95_latency for m in metrics]
        avg_docs = [m.avg_docs_retrieved for m in metrics]
        # Guard against a zero average latency instead of dividing by zero.
        speeds = [
            1.0 / m.avg_latency_total if m.avg_latency_total > 0 else 0.0
            for m in metrics
        ]

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle("RAG Methods Performance Comparison", fontsize=16, fontweight="bold")

        x = np.arange(len(methods))
        width = 0.25

        # Panel 1: three side-by-side bars per method.
        ax1 = axes[0, 0]
        ax1.bar(x - width, avg_retrieval, width, label="Retrieval", color="#4CAF50", alpha=0.8)
        ax1.bar(x, avg_generation, width, label="Generation", color="#2196F3", alpha=0.8)
        ax1.bar(x + width, avg_total, width, label="Total", color="#FF9800", alpha=0.8)
        ax1.set_xlabel("Method")
        ax1.set_ylabel("Time (seconds)")
        ax1.set_title("Average Latency Comparison")
        ax1.set_xticks(x)
        ax1.set_xticklabels(methods, rotation=45, ha="right")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Panel 2: two bars per method, centred on the tick so they sit
        # side by side rather than overlapping.
        ax2 = axes[0, 1]
        ax2.bar(x - width / 2, p50, width, label="P50", color="#9C27B0", alpha=0.8)
        ax2.bar(x + width / 2, p95, width, label="P95", color="#E91E63", alpha=0.8)
        ax2.set_xlabel("Method")
        ax2.set_ylabel("Time (seconds)")
        ax2.set_title("Percentile Latency Comparison")
        ax2.set_xticks(x)
        ax2.set_xticklabels(methods, rotation=45, ha="right")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Panel 3: average number of retrieved documents.
        ax3 = axes[1, 0]
        ax3.barh(methods, avg_docs, color="#00BCD4", alpha=0.8)
        ax3.set_xlabel("Average Number of Documents")
        ax3.set_title("Average Retrieved Documents")
        ax3.grid(True, alpha=0.3, axis="x")

        # Panel 4: throughput derived from the average total latency.
        ax4 = axes[1, 1]
        ax4.barh(methods, speeds, color="#8BC34A", alpha=0.8)
        ax4.set_xlabel("Queries per Second")
        ax4.set_title("Throughput Comparison")
        ax4.grid(True, alpha=0.3, axis="x")

        plt.tight_layout()

        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            logger.info("Plot saved to {}".format(save_path))

        if show:
            plt.show()

        # Close this specific figure so repeated calls do not accumulate figures.
        plt.close(fig)

    def generate_report(
        self,
        benchmark_result: BenchmarkResult,
        save_path: Optional[str] = None,
    ) -> str:
        """Build a Markdown report for ``benchmark_result``.

        Args:
            benchmark_result: The run to report on.
            save_path: If given, the report is also written there as UTF-8
                (parent directories are created).

        Returns:
            The report text. When there are no metrics (for example every
            query failed) the summary table is empty and the analysis section
            states that there is nothing to analyze.
        """
        metrics = benchmark_result.metrics

        parts = [
            "# RAG Benchmark Report\n\n",
            "**Timestamp**: {}\n\n".format(benchmark_result.timestamp),
            "## Summary\n\n",
            "| Method | Avg Total(s) | Avg Retrieval(s) | Avg Generation(s) | P50(s) | P95(s) |\n",
            "|--------|-------------|-----------------|-------------------|--------|--------|\n",
        ]

        for m in sorted(metrics, key=lambda x: x.avg_latency_total):
            parts.append(
                "| {} | {:.3f} | {:.3f} | {:.3f} | {:.3f} | {:.3f} |\n".format(
                    m.method_name,
                    m.avg_latency_total,
                    m.avg_latency_retrieval,
                    m.avg_latency_generation,
                    m.p50_latency,
                    m.p95_latency,
                )
            )

        parts.append("\n## Detailed Analysis\n\n")

        if metrics:
            fastest = min(metrics, key=lambda x: x.avg_latency_total)
            parts.append(
                "- **Fastest Method**: {} ({:.3f}s average)\n".format(
                    fastest.method_name, fastest.avg_latency_total
                )
            )
            most_docs = max(metrics, key=lambda x: x.avg_docs_retrieved)
            parts.append(
                "- **Most Documents Retrieved**: {} ({:.1f} average)\n".format(
                    most_docs.method_name, most_docs.avg_docs_retrieved
                )
            )
            most_stable = min(metrics, key=lambda x: x.latency_std)
            parts.append(
                "- **Most Stable**: {} (std={:.3f})\n".format(
                    most_stable.method_name, most_stable.latency_std
                )
            )
        else:
            # min()/max() would raise ValueError on an empty sequence.
            parts.append("- No successful results to analyze.\n")

        report = "".join(parts)

        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            with open(save_path, "w", encoding="utf-8") as f:
                f.write(report)
            logger.info("Report saved to {}".format(save_path))

        return report