bid_field_extraction_demo.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. """
  2. Bid field extraction demo - demonstrates structured information extraction
  3. from bidding announcements using RAG
  4. """
  5. import sys
  6. import os
  7. import json
  8. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. from openai import OpenAI
  10. from bdirag.document_processor import Document
  11. from bdirag.embedding_models import SentenceTransformerEmbedding
  12. from bdirag.vector_stores import FAISSStore
  13. from bdirag.rag_methods import BidFieldExtractionRAG
  14. from examples.sample_data import SAMPLE_BIDDING_DOCS
  15. def main():
  16. print("=" * 60)
  17. print("BidiRAG - Bid Field Extraction Demo")
  18. print("=" * 60)
  19. # Configuration
  20. LLM_API_KEY = os.getenv("OPENAI_API_KEY", "your-api-key-here")
  21. LLM_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
  22. LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o")
  23. # Load embedding model
  24. print("\n[1/3] Loading embedding model...")
  25. embedding_model = SentenceTransformerEmbedding(
  26. model_name="BAAI/bge-large-zh-v1.5",
  27. device="cpu"
  28. )
  29. # Create vector store and index
  30. print("\n[2/3] Indexing bidding documents...")
  31. vector_store = FAISSStore(embedding_model=embedding_model)
  32. documents = [
  33. Document(page_content=doc["content"], metadata={"title": doc["title"], "source": doc["title"]})
  34. for doc in SAMPLE_BIDDING_DOCS
  35. ]
  36. llm_client = OpenAI(api_key=LLM_API_KEY, base_url=LLM_BASE_URL)
  37. rag = BidFieldExtractionRAG(
  38. embedding_model=embedding_model,
  39. vector_store=vector_store,
  40. llm_client=llm_client,
  41. llm_model=LLM_MODEL,
  42. )
  43. rag.index_documents(documents)
  44. print(" Indexed {0} documents.format(len(documents))")
  45. # Extract fields for each bidding document
  46. print("\n[3/3] Extracting fields from bidding documents...")
  47. for doc in SAMPLE_BIDDING_DOCS:
  48. print("\n{0}.format('=' * 60)")
  49. print("Document: {0}.format(doc['title'])")
  50. print("{0}.format('=' * 60)")
  51. query = "Extract all information from {0}.format(doc['title'])"
  52. try:
  53. result = rag.query(query, k=10)
  54. print(f"\nExtracted JSON:")
  55. print(result.answer)
  56. print("\nLatency: {0}s.format(result.latency_total:.3f)")
  57. print("Retrieved {0} document chunks.format(len(result.retrieved_docs))")
  58. except Exception as e:
  59. print("ERROR: {0}.format(e)")
  60. print("\n\nExtraction complete!")
  61. if __name__ == "__main__":
  62. main()