rag_test_utils.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # -*- coding: utf-8 -*-
  2. """Shared helpers for standalone RAG example tests."""
  3. import math
  4. import os
  5. import re
  6. import sys
  7. import time
  8. import types
  9. ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  10. if ROOT_DIR not in sys.path:
  11. sys.path.insert(0, ROOT_DIR)
  12. from bdirag.document_processor import Document
  13. from examples.sample_data import SAMPLE_BIDDING_DOCS
  14. TEST_QUERIES = [
  15. "project budget amount",
  16. "bid bond amount",
  17. "qualification requirements",
  18. "evaluation method",
  19. "warranty period",
  20. "payment terms",
  21. "project code XX-ZB",
  22. "delivery time",
  23. ]
  24. def install_rank_bm25_fallback():
  25. """Expose the repo BM25 fallback as rank_bm25 for methods that import it."""
  26. if "rank_bm25" in sys.modules:
  27. return
  28. from bdirag.rag_methods.bm25_backend import SimpleBM25Okapi
  29. module = types.ModuleType("rank_bm25")
  30. module.BM25Okapi = SimpleBM25Okapi
  31. sys.modules["rank_bm25"] = module
  32. class _ScoreList(list):
  33. def flatten(self):
  34. return self
  35. class _SimpleTfidfVectorizer(object):
  36. def __init__(self):
  37. self.vocabulary_ = {}
  38. def fit_transform(self, texts):
  39. for text in texts:
  40. for token in _tokens(text):
  41. if token not in self.vocabulary_:
  42. self.vocabulary_[token] = len(self.vocabulary_)
  43. return self.transform(texts)
  44. def transform(self, texts):
  45. vectors = []
  46. for text in texts:
  47. vector = [0.0] * len(self.vocabulary_)
  48. for token in _tokens(text):
  49. index = self.vocabulary_.get(token)
  50. if index is not None:
  51. vector[index] += 1.0
  52. norm = math.sqrt(sum(v * v for v in vector)) or 1.0
  53. vectors.append([v / norm for v in vector])
  54. return vectors
  55. def _simple_cosine_similarity(query_vecs, matrix):
  56. query = query_vecs[0] if query_vecs else []
  57. scores = _ScoreList()
  58. for vector in matrix:
  59. scores.append(sum(a * b for a, b in zip(query, vector)))
  60. return scores
  61. class _SimpleKMeans(object):
  62. def __init__(self, n_clusters=2, random_state=None):
  63. self.n_clusters = max(1, int(n_clusters))
  64. self.random_state = random_state
  65. def fit_predict(self, embeddings):
  66. return [i % self.n_clusters for i in range(len(embeddings))]
  67. def install_sklearn_fallback():
  68. """Install tiny sklearn-compatible modules used by the example tests."""
  69. sklearn = types.ModuleType("sklearn")
  70. feature_extraction = types.ModuleType("sklearn.feature_extraction")
  71. text = types.ModuleType("sklearn.feature_extraction.text")
  72. metrics = types.ModuleType("sklearn.metrics")
  73. pairwise = types.ModuleType("sklearn.metrics.pairwise")
  74. cluster = types.ModuleType("sklearn.cluster")
  75. text.TfidfVectorizer = _SimpleTfidfVectorizer
  76. pairwise.cosine_similarity = _simple_cosine_similarity
  77. cluster.KMeans = _SimpleKMeans
  78. sklearn.feature_extraction = feature_extraction
  79. sklearn.metrics = metrics
  80. sklearn.cluster = cluster
  81. feature_extraction.text = text
  82. metrics.pairwise = pairwise
  83. sys.modules["sklearn"] = sklearn
  84. sys.modules["sklearn.feature_extraction"] = feature_extraction
  85. sys.modules["sklearn.feature_extraction.text"] = text
  86. sys.modules["sklearn.metrics"] = metrics
  87. sys.modules["sklearn.metrics.pairwise"] = pairwise
  88. sys.modules["sklearn.cluster"] = cluster
  89. def make_documents():
  90. return [
  91. Document(
  92. page_content=doc["content"],
  93. metadata={"title": doc["title"], "source": doc["title"]},
  94. )
  95. for doc in SAMPLE_BIDDING_DOCS
  96. ]
  97. def _tokens(text):
  98. return re.findall(r"[A-Za-z0-9]+", (text or "").lower())
  99. class FakeEmbedding(object):
  100. def __init__(self, dimension=64):
  101. self.dimension = dimension
  102. def _embed(self, text):
  103. vector = [0.0] * self.dimension
  104. for token in _tokens(text):
  105. index = sum(ord(ch) for ch in token) % self.dimension
  106. vector[index] += 1.0
  107. norm = math.sqrt(sum(v * v for v in vector)) or 1.0
  108. return [v / norm for v in vector]
  109. def embed_documents(self, texts):
  110. return [self._embed(text) for text in texts]
  111. def embed_query(self, text):
  112. return self._embed(text)
  113. class SimpleVectorStore(object):
  114. def __init__(self):
  115. self.documents = []
  116. self.embeddings = []
  117. def add_documents(self, documents, embeddings):
  118. self.documents.extend(documents)
  119. self.embeddings.extend(embeddings)
  120. def similarity_search(self, query_embedding, k=10):
  121. scored = []
  122. for doc, embedding in zip(self.documents, self.embeddings):
  123. score = sum(a * b for a, b in zip(query_embedding, embedding))
  124. if score > 0:
  125. scored.append((doc, float(score)))
  126. scored.sort(key=lambda item: item[1], reverse=True)
  127. return scored[:k]
  128. def save(self, path):
  129. return None
  130. def load(self, path):
  131. return None
  132. class _FakeMessage(object):
  133. def __init__(self, content):
  134. self.content = content
  135. class _FakeChoice(object):
  136. def __init__(self, content):
  137. self.message = _FakeMessage(content)
  138. class _FakeResponse(object):
  139. def __init__(self, content):
  140. self.choices = [_FakeChoice(content)]
  141. class _FakeCompletions(object):
  142. def create(self, model=None, messages=None, temperature=None, max_tokens=None):
  143. prompt = messages[-1]["content"] if messages else ""
  144. return _FakeResponse(fake_llm_text(prompt))
  145. class _FakeChat(object):
  146. def __init__(self):
  147. self.completions = _FakeCompletions()
  148. class FakeLLMClient(object):
  149. def __init__(self):
  150. self.chat = _FakeChat()
  151. def fake_llm_text(prompt):
  152. lower = (prompt or "").lower()
  153. if "json" in lower:
  154. return (
  155. '{"project_name": "sample bidding project", '
  156. '"budget_amount": "sample budget", '
  157. '"evaluation_method": "sample evaluation method"}'
  158. )
  159. if "0-1" in prompt or "0-1" in lower or "score" in lower:
  160. return "0.8"
  161. if "category" in lower or "classify" in lower:
  162. return "budget"
  163. if "simple" in lower and "complex" in lower:
  164. return "simple"
  165. if "sub" in lower or "variant" in lower:
  166. return "project budget\nqualification requirements\ndelivery time"
  167. if "step" in lower or "general" in lower:
  168. return "general bidding project information"
  169. return "Sample offline LLM answer for the standalone RAG test."
  170. class FakeRerankModel(object):
  171. def compute_score(self, pairs):
  172. scores = []
  173. for query, text in pairs:
  174. query_words = set(_tokens(query))
  175. text_words = set(_tokens(text))
  176. scores.append(float(len(query_words & text_words)) + 0.1)
  177. return scores
  178. def build_vector_rag(rag_cls, **kwargs):
  179. options = {
  180. "embedding_model": FakeEmbedding(),
  181. "vector_store": SimpleVectorStore(),
  182. "llm_client": FakeLLMClient(),
  183. }
  184. options.update(kwargs)
  185. return rag_cls(**options)
  186. def print_results(method_name, query, results, elapsed):
  187. print("\nQuery: {}".format(query))
  188. print("-" * 60)
  189. print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
  190. for i, (doc, score) in enumerate(results, 1):
  191. title = doc.metadata.get("title", doc.metadata.get("source", "Unknown"))
  192. preview = doc.page_content[:100].replace("\n", " ")
  193. print(" [{}] {} (Score: {:.4f})".format(i, title, float(score)))
  194. print(" Preview: {}...".format(preview))
  195. def jinrun_retrieval_test(method_name, rag, index_func=None, queries=None, k=3):
  196. print("=" * 60)
  197. print("{} - Standalone Retrieval Test".format(method_name))
  198. print("=" * 60)
  199. documents = make_documents()
  200. print("\n[1/2] Preparing documents...")
  201. print(" Prepared {} documents".format(len(documents)))
  202. print("\n[2/2] Building index...")
  203. if index_func is None:
  204. rag.index_documents(documents)
  205. else:
  206. index_func(rag, documents)
  207. print(" Index built successfully")
  208. print("\n" + "=" * 60)
  209. print("{} Retrieval Test Results".format(method_name))
  210. print("=" * 60)
  211. for query in queries or TEST_QUERIES:
  212. start = time.time()
  213. results = rag.retrieve(query, k=k)
  214. elapsed = time.time() - start
  215. print_results(method_name, query, results, elapsed)
  216. print("\n{} test complete!".format(method_name))
  217. run_retrieval_test = jinrun_retrieval_test
  218. def run_html_tree_test(rag_cls):
  219. sample_html = """
  220. <html><body>
  221. <h1>Sample Medical Equipment Procurement Bidding Announcement</h1>
  222. <h2>Project Overview</h2>
  223. <p>Project budget: 28,000,000 RMB</p>
  224. <p>Project code: XX-HOSP-2024-015</p>
  225. <h2>Qualification Requirements</h2>
  226. <p>Must have Medical Device Operation License and ISO 13485 certification.</p>
  227. <h2>Delivery and Warranty</h2>
  228. <p>Delivery time: within 90 calendar days. Warranty period: minimum 5 years.</p>
  229. <h2>Payment Terms</h2>
  230. <p>100% payment after installation, debugging and acceptance.</p>
  231. </body></html>
  232. """
  233. print("=" * 60)
  234. print("{} - Standalone HTML Tree Test".format(rag_cls.__name__))
  235. print("=" * 60)
  236. rag = rag_cls()
  237. rag.build_index(sample_html)
  238. print(" HTML tree index built successfully")
  239. for query in ["project budget", "qualification requirements", "delivery time", "payment terms"]:
  240. start = time.time()
  241. results = rag.query(query, k=3)
  242. elapsed = time.time() - start
  243. print_results(rag_cls.__name__, query, results, elapsed)
  244. print("\n{} test complete!".format(rag_cls.__name__))