# quick_demo.py
  1. """
  2. Simple demo script to test basic RAG methods without full benchmark
  3. Good for quick validation and understanding
  4. """
  5. import sys
  6. import os
  7. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  8. from openai import OpenAI
  9. from bdirag.document_processor import Document
  10. from bdirag.embedding_models import SentenceTransformerEmbedding
  11. from bdirag.vector_stores import FAISSStore
  12. from bdirag.rag_methods import (
  13. NaiveRAG,
  14. BidFieldExtractionRAG,
  15. HyDERAG,
  16. )
  17. from examples.sample_data import SAMPLE_BIDDING_DOCS
  18. def main():
  19. print("=" * 60)
  20. print("BidiRAG - Quick Demo")
  21. print("=" * 60)
  22. # Configuration - modify these as needed
  23. LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
  24. LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
  25. LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
  26. # Step 1: Load embedding model
  27. print("\n[1/4] Loading embedding model...")
  28. embedding_model = SentenceTransformerEmbedding(
  29. model_name="BAAI/bge-large-zh-v1.5",
  30. device="cpu"
  31. )
  32. print(" Model loaded, dimension: {0}.format(embedding_model.dimension)")
  33. # Step 2: Create vector store and index documents
  34. print("\n[2/4] Creating vector store and indexing documents...")
  35. vector_store = FAISSStore(embedding_model=embedding_model)
  36. documents = [
  37. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  38. for doc in SAMPLE_BIDDING_DOCS
  39. ]
  40. print(" Prepared {0} documents.format(len(documents))")
  41. # Step 3: Initialize RAG methods
  42. print("\n[3/4] Initializing RAG methods...")
  43. llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
  44. naive_rag = NaiveRAG(
  45. embedding_model=embedding_model,
  46. vector_store=vector_store,
  47. llm_client=llm_client,
  48. llm_model=LLM_MODEL,
  49. )
  50. bid_rag = BidFieldExtractionRAG(
  51. embedding_model=embedding_model,
  52. vector_store=vector_store,
  53. llm_client=llm_client,
  54. llm_model=LLM_MODEL,
  55. )
  56. hyde_rag = HyDERAG(
  57. embedding_model=embedding_model,
  58. vector_store=vector_store,
  59. llm_client=llm_client,
  60. llm_model=LLM_MODEL,
  61. )
  62. naive_rag.index_documents(documents)
  63. bid_rag.index_documents(documents)
  64. hyde_rag.index_documents(documents)
  65. print(" Indexing complete")
  66. # Step 4: Test queries
  67. print("\n[4/4] Running test queries...")
  68. queries = [
  69. "What is the budget for the smart transportation project?",
  70. "List the qualification requirements for all projects",
  71. "What are the payment terms for the road construction project?",
  72. ]
  73. methods = [
  74. ("NaiveRAG", naive_rag),
  75. ("BidFieldExtractionRAG", bid_rag),
  76. ("HyDERAG", hyde_rag),
  77. ]
  78. for query in queries:
  79. print("\n{0}.format('=' * 60)")
  80. print("Query: {0}.format(query)")
  81. print("{0}.format('=' * 60)")
  82. for method_name, method in methods:
  83. print("\n--- {0} ---.format(method_name)")
  84. try:
  85. result = method.query(query, k=5)
  86. print("Answer: {0}.format(result.answer)")
  87. print("Latency: {0}s (retrieval: {1}s, generation: {2}s).format(result.latency_total:.3f, result.latency_retrieval:.3f, result.latency_generation:.3f)")
  88. print("Retrieved {0} documents.format(len(result.retrieved_docs))")
  89. except Exception as e:
  90. print("ERROR: {0}.format(e)")
  91. print("\n\nDemo complete!")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()