| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317 |
- # -*- coding: utf-8 -*-
- """Shared helpers for standalone RAG example tests."""
- import math
- import os
- import re
- import sys
- import time
- import types
# Repository root: the parent of the directory that holds this file. It is
# put at the front of sys.path (only once) so the project-local imports that
# follow resolve even when this module is run from another working directory.
ROOT_DIR = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
- from bdirag.document_processor import Document
- from examples.sample_data import SAMPLE_BIDDING_DOCS
# Default query set used by jinrun_retrieval_test when no queries are given.
TEST_QUERIES = [
    # amounts
    "project budget amount",
    "bid bond amount",
    # eligibility and scoring
    "qualification requirements",
    "evaluation method",
    # contract terms
    "warranty period",
    "payment terms",
    # identifiers and logistics
    "project code XX-ZB",
    "delivery time",
]
def install_rank_bm25_fallback():
    """Register the repo's BM25 implementation under the name ``rank_bm25``.

    If a module called ``rank_bm25`` is already registered in ``sys.modules``
    (the real package or an earlier shim), nothing is changed. Otherwise a
    synthetic module exposing ``BM25Okapi`` (backed by
    ``bdirag.rag_methods.bm25_backend.SimpleBM25Okapi``) is registered, so a
    later ``from rank_bm25 import BM25Okapi`` picks up the fallback.
    """
    if "rank_bm25" not in sys.modules:
        # Imported lazily: only needed when the shim is actually installed.
        from bdirag.rag_methods.bm25_backend import SimpleBM25Okapi

        shim = types.ModuleType("rank_bm25")
        shim.BM25Okapi = SimpleBM25Okapi
        sys.modules["rank_bm25"] = shim
class _ScoreList(list):
    """A plain ``list`` of scores that also answers ``flatten()``.

    Code written against an array-like result can call ``.flatten()`` on it;
    this class makes that call work for a list that is already flat.
    """

    def flatten(self):
        # Already one-dimensional: return this very object, not a copy.
        return self
class _SimpleTfidfVectorizer(object):
    """Minimal stand-in for ``sklearn.feature_extraction.text.TfidfVectorizer``.

    Despite the name, no IDF weighting is applied. Each text becomes a
    token-count vector over the fitted vocabulary, L2-normalised. Tokens come
    from the module-level ``_tokens`` helper.
    """

    def __init__(self):
        # token -> column index, assigned in first-seen order while fitting
        self.vocabulary_ = {}

    def fit_transform(self, texts):
        """Learn a fresh vocabulary from *texts* and return their vectors.

        The vocabulary is rebuilt on every call rather than extended, so
        refitting on a new corpus does not carry over columns from an earlier
        fit (matching the sklearn method this emulates).

        Args:
            texts: iterable of strings (or ``None``); may be a one-shot iterator.

        Returns:
            list of lists of floats, one L2-normalised row per text.
        """
        # Materialise once: the texts are walked twice (fit, then transform),
        # which would silently produce no rows for a generator otherwise.
        texts = list(texts)
        vocabulary = {}
        for text in texts:
            for token in _tokens(text):
                # len() is taken before insertion, so new tokens get the next index.
                vocabulary.setdefault(token, len(vocabulary))
        self.vocabulary_ = vocabulary
        return self.transform(texts)

    def transform(self, texts):
        """Return one L2-normalised count vector per text.

        Tokens not in the fitted vocabulary are ignored; a text with no known
        tokens yields an all-zero row (the ``or 1.0`` avoids dividing by zero).
        """
        width = len(self.vocabulary_)
        vectors = []
        for text in texts:
            vector = [0.0] * width
            for token in _tokens(text):
                index = self.vocabulary_.get(token)
                if index is not None:
                    vector[index] += 1.0
            norm = math.sqrt(sum(v * v for v in vector)) or 1.0
            vectors.append([v / norm for v in vector])
        return vectors
def _simple_cosine_similarity(query_vecs, matrix):
    """Fallback for ``sklearn.metrics.pairwise.cosine_similarity``.

    Only the first row of *query_vecs* is scored against every row of
    *matrix*. The result is a flat ``_ScoreList`` so that code expecting an
    array result can still call ``.flatten()`` on it.

    Both sides are L2-normalised, so this is a true cosine even for
    unnormalised inputs. A zero-length vector on either side scores ``0.0``.

    Args:
        query_vecs: sequence of query vectors (list of lists or a 2-D array);
            ``None`` or empty means "no query".
        matrix: iterable of candidate vectors.

    Returns:
        ``_ScoreList`` of floats, one per row of *matrix*.
    """
    # Explicit None/len test instead of truthiness, which is ambiguous
    # (raises) for multi-element array inputs.
    has_query = query_vecs is not None and len(query_vecs) > 0
    query = query_vecs[0] if has_query else []
    query_norm = math.sqrt(sum(q * q for q in query))
    scores = _ScoreList()
    for vector in matrix:
        dot = sum(a * b for a, b in zip(query, vector))
        denom = query_norm * math.sqrt(sum(v * v for v in vector))
        scores.append(dot / denom if denom else 0.0)
    return scores
class _SimpleKMeans(object):
    """Deterministic stand-in for ``sklearn.cluster.KMeans``.

    No clustering is performed. Labels are assigned round-robin by input
    position, so the output depends only on how many items are given.
    """

    def __init__(self, n_clusters=2, random_state=None):
        # Coerced to int and clamped so at least one cluster always exists.
        self.n_clusters = max(int(n_clusters), 1)
        # Stored for API parity only; nothing here reads it.
        self.random_state = random_state

    def fit_predict(self, embeddings):
        """Return label ``i % n_clusters`` for the i-th item of *embeddings*."""
        total = len(embeddings)
        clusters = self.n_clusters
        return [position % clusters for position in range(total)]
def install_sklearn_fallback():
    """Install tiny sklearn-compatible modules used by the example tests.

    Six synthetic modules are registered in ``sys.modules``: ``sklearn``,
    ``sklearn.feature_extraction``, ``sklearn.feature_extraction.text``,
    ``sklearn.metrics``, ``sklearn.metrics.pairwise`` and ``sklearn.cluster``.
    Each is linked as an attribute of its parent package. They expose:

    - ``TfidfVectorizer`` -> ``_SimpleTfidfVectorizer``
    - ``cosine_similarity`` -> ``_simple_cosine_similarity``
    - ``KMeans`` -> ``_SimpleKMeans``

    Any existing entries under these names are replaced unconditionally.
    """
    names = (
        "sklearn",
        "sklearn.feature_extraction",
        "sklearn.feature_extraction.text",
        "sklearn.metrics",
        "sklearn.metrics.pairwise",
        "sklearn.cluster",
    )
    modules = dict((name, types.ModuleType(name)) for name in names)

    modules["sklearn.feature_extraction.text"].TfidfVectorizer = _SimpleTfidfVectorizer
    modules["sklearn.metrics.pairwise"].cosine_similarity = _simple_cosine_similarity
    modules["sklearn.cluster"].KMeans = _SimpleKMeans

    # Hang every submodule off its parent so attribute access
    # (e.g. ``sklearn.metrics.pairwise``) works like a real package.
    for name, module in modules.items():
        parent, _, child = name.rpartition(".")
        if parent:
            setattr(modules[parent], child, module)

    sys.modules.update(modules)
def make_documents():
    """Wrap every entry of ``SAMPLE_BIDDING_DOCS`` in a ``Document``.

    Each entry's ``"content"`` becomes ``page_content``. Its ``"title"`` is
    stored under both the ``"title"`` and ``"source"`` metadata keys.
    """
    documents = []
    for sample in SAMPLE_BIDDING_DOCS:
        content = sample["content"]
        title = sample["title"]
        documents.append(
            Document(page_content=content, metadata={"title": title, "source": title})
        )
    return documents
# Compiled once at import time: runs of ASCII letters and digits.
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9]+")


def _tokens(text):
    """Split *text* into lower-cased ``[A-Za-z0-9]+`` tokens.

    Falsy input (``None``, ``""``) yields an empty list. After lower-casing,
    every character outside ``[A-Za-z0-9]`` acts as a separator and is dropped.
    """
    if not text:
        return []
    return _TOKEN_PATTERN.findall(text.lower())
class FakeEmbedding(object):
    """Deterministic, offline embedding model for tests.

    Each token (from ``_tokens``) is hashed into a bucket by the sum of its
    character code points modulo ``dimension``. Bucket counts are then
    L2-normalised; text with no tokens maps to an all-zero vector.
    """

    def __init__(self, dimension=64):
        self.dimension = dimension

    def _embed(self, text):
        buckets = [0.0] * self.dimension
        for token in _tokens(text):
            buckets[sum(map(ord, token)) % self.dimension] += 1.0
        # ``or 1.0`` keeps an all-zero vector from dividing by zero.
        length = math.sqrt(sum(b * b for b in buckets)) or 1.0
        return [b / length for b in buckets]

    def embed_documents(self, texts):
        """Embed each text in *texts*, preserving order."""
        return list(map(self._embed, texts))

    def embed_query(self, text):
        """Embed a single query string."""
        return self._embed(text)
class SimpleVectorStore(object):
    """In-memory vector store that ranks documents by dot product.

    ``documents`` and ``embeddings`` are parallel lists (item i of one belongs
    to item i of the other). Nothing is persisted: ``save``/``load`` are no-ops.
    """

    def __init__(self):
        self.documents = []
        self.embeddings = []

    def add_documents(self, documents, embeddings):
        """Append *documents* together with their matching *embeddings*.

        Args:
            documents: iterable of document objects.
            embeddings: iterable of vectors, one per document, same order.

        Raises:
            ValueError: if the two inputs differ in length. Checked before
                anything is stored, so the parallel lists never drift apart.
        """
        # Materialise first so lengths can be compared for any iterable.
        documents = list(documents)
        embeddings = list(embeddings)
        if len(documents) != len(embeddings):
            raise ValueError(
                "add_documents got {} documents but {} embeddings".format(
                    len(documents), len(embeddings)
                )
            )
        self.documents.extend(documents)
        self.embeddings.extend(embeddings)

    def similarity_search(self, query_embedding, k=10):
        """Return up to *k* ``(document, score)`` pairs, highest score first.

        The score is the dot product with *query_embedding*, which equals the
        cosine similarity when both vectors are unit-length. Pairs scoring
        ``<= 0`` are dropped. Ties keep insertion order (the sort is stable).
        """
        scored = []
        for doc, embedding in zip(self.documents, self.embeddings):
            score = sum(a * b for a, b in zip(query_embedding, embedding))
            if score > 0:
                scored.append((doc, float(score)))
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:k]

    def save(self, path):
        """No-op: this store keeps nothing on disk."""
        return None

    def load(self, path):
        """No-op: this store keeps nothing on disk."""
        return None
class _FakeMessage(object):
    """Message holder exposing only ``content``."""

    def __init__(self, content):
        self.content = content


class _FakeChoice(object):
    """One completion choice wrapping a ``_FakeMessage``."""

    def __init__(self, content):
        self.message = _FakeMessage(content)


class _FakeResponse(object):
    """Completion response carrying exactly one choice."""

    def __init__(self, content):
        self.choices = [_FakeChoice(content)]


class _FakeCompletions(object):
    """Offline ``completions`` endpoint answered by ``fake_llm_text``."""

    def create(self, model=None, messages=None, temperature=None, max_tokens=None,
               **kwargs):
        """Return a canned response for the last message's ``"content"``.

        ``model``, ``temperature`` and ``max_tokens`` are ignored. Any further
        keyword options a caller forwards are accepted and ignored too, rather
        than failing with ``TypeError``.

        Args:
            messages: sequence of mappings with a ``"content"`` key; when empty
                or ``None`` the prompt is ``""``.
        """
        prompt = messages[-1]["content"] if messages else ""
        return _FakeResponse(fake_llm_text(prompt))


class _FakeChat(object):
    """Namespace object exposing ``completions``."""

    def __init__(self):
        self.completions = _FakeCompletions()


class FakeLLMClient(object):
    """Offline LLM client for tests.

    ``client.chat.completions.create(messages=...)`` returns an object whose
    ``choices[0].message.content`` is a keyword-selected canned string.
    """

    def __init__(self):
        self.chat = _FakeChat()
def fake_llm_text(prompt):
    """Return a canned LLM reply selected by keywords in *prompt*.

    Matching is case-insensitive, and the first rule that fires wins, so the
    order below is significant. For example, a prompt mentioning both "json"
    and "score" gets the JSON reply. ``None`` is treated as an empty prompt
    and falls through to the generic answer.

    Args:
        prompt: prompt text, or ``None``.

    Returns:
        str: the canned reply.
    """
    lower = (prompt or "").lower()
    if "json" in lower:
        return (
            '{"project_name": "sample bidding project", '
            '"budget_amount": "sample budget", '
            '"evaluation_method": "sample evaluation method"}'
        )
    # Only the lower-cased copy is inspected: "0-1" has no case, and the raw
    # prompt may be None (testing membership in None raises TypeError).
    if "0-1" in lower or "score" in lower:
        return "0.8"
    if "category" in lower or "classify" in lower:
        return "budget"
    if "simple" in lower and "complex" in lower:
        return "simple"
    # Newline-separated list: one query per line.
    if "sub" in lower or "variant" in lower:
        return "project budget\nqualification requirements\ndelivery time"
    if "step" in lower or "general" in lower:
        return "general bidding project information"
    return "Sample offline LLM answer for the standalone RAG test."
class FakeRerankModel(object):
    """Offline reranker scoring pairs by shared vocabulary."""

    def compute_score(self, pairs):
        """Score each ``(query, text)`` pair.

        The score is the number of distinct tokens the two sides share, plus
        0.1, as a float.
        """
        def _overlap(query, text):
            shared = set(_tokens(query)).intersection(_tokens(text))
            return float(len(shared)) + 0.1

        return [_overlap(query, text) for query, text in pairs]
def build_vector_rag(rag_cls, **kwargs):
    """Instantiate *rag_cls* wired to the offline fakes.

    Default ``embedding_model``, ``vector_store`` and ``llm_client`` values are
    fresh ``FakeEmbedding``, ``SimpleVectorStore`` and ``FakeLLMClient``
    instances. Any keyword in *kwargs* overrides the matching default or is
    passed through unchanged.
    """
    defaults = dict(
        embedding_model=FakeEmbedding(),
        vector_store=SimpleVectorStore(),
        llm_client=FakeLLMClient(),
    )
    return rag_cls(**dict(defaults, **kwargs))
def print_results(method_name, query, results, elapsed):
    """Print one query's retrieval results to stdout.

    Args:
        method_name: accepted for call-site symmetry; not printed.
        query: the query string shown in the header.
        results: sized sequence of ``(doc, score)`` pairs. Each ``doc`` needs
            a ``metadata`` mapping and a ``page_content`` string.
        elapsed: seconds taken, shown with four decimals.
    """
    print("\nQuery: {}".format(query))
    print("-" * 60)
    print(" Retrieved {} documents in {:.4f}s".format(len(results), elapsed))
    for rank, (doc, score) in enumerate(results, start=1):
        meta = doc.metadata
        # Prefer the title, then the source, then a placeholder.
        label = meta.get("title", meta.get("source", "Unknown"))
        snippet = doc.page_content[:100].replace("\n", " ")
        print(" [{}] {} (Score: {:.4f})".format(rank, label, float(score)))
        print(" Preview: {}...".format(snippet))
def jinrun_retrieval_test(method_name, rag, index_func=None, queries=None, k=3):
    """Index the sample bidding documents with *rag* and print retrieval results.

    Args:
        method_name: label used in the printed headers.
        rag: object exposing ``retrieve(query, k=...)``, which returns
            ``(doc, score)`` pairs. When *index_func* is ``None``, it must also
            expose ``index_documents(documents)``.
        index_func: optional ``index_func(rag, documents)`` callable used
            instead of ``rag.index_documents``.
        queries: queries to run. ``None`` means ``TEST_QUERIES``; an explicit
            empty sequence runs no queries.
        k: number of results requested per query.
    """
    banner = "=" * 60
    print(banner)
    print("{} - Standalone Retrieval Test".format(method_name))
    print(banner)

    # Announce the step before doing the work it describes.
    print("\n[1/2] Preparing documents...")
    documents = make_documents()
    print(" Prepared {} documents".format(len(documents)))

    print("\n[2/2] Building index...")
    if index_func is None:
        rag.index_documents(documents)
    else:
        index_func(rag, documents)
    print(" Index built successfully")

    print("\n" + banner)
    print("{} Retrieval Test Results".format(method_name))
    print(banner)
    for query in (TEST_QUERIES if queries is None else queries):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with clock adjustments and is coarse on some platforms.
        start = time.perf_counter()
        results = rag.retrieve(query, k=k)
        elapsed = time.perf_counter() - start
        print_results(method_name, query, results, elapsed)
    print("\n{} test complete!".format(method_name))


# Conventional public name. The ``jinrun_`` spelling is presumably a typo but
# is kept so existing callers continue to work.
run_retrieval_test = jinrun_retrieval_test
def run_html_tree_test(rag_cls, queries=None, k=3):
    """Build an HTML-tree index from a fixed sample page and print query results.

    Args:
        rag_cls: class instantiated with no arguments. The instance must
            expose ``build_index(html)`` and ``query(query, k=...)``, which
            returns ``(doc, score)`` pairs.
        queries: queries to run. ``None`` uses the four built-in ones; an
            explicit empty sequence runs no queries.
        k: number of results requested per query.
    """
    sample_html = """
<html><body>
<h1>Sample Medical Equipment Procurement Bidding Announcement</h1>
<h2>Project Overview</h2>
<p>Project budget: 28,000,000 RMB</p>
<p>Project code: XX-HOSP-2024-015</p>
<h2>Qualification Requirements</h2>
<p>Must have Medical Device Operation License and ISO 13485 certification.</p>
<h2>Delivery and Warranty</h2>
<p>Delivery time: within 90 calendar days. Warranty period: minimum 5 years.</p>
<h2>Payment Terms</h2>
<p>100% payment after installation, debugging and acceptance.</p>
</body></html>
"""
    if queries is None:
        queries = ["project budget", "qualification requirements", "delivery time", "payment terms"]
    name = rag_cls.__name__

    print("=" * 60)
    print("{} - Standalone HTML Tree Test".format(name))
    print("=" * 60)
    rag = rag_cls()
    rag.build_index(sample_html)
    print(" HTML tree index built successfully")
    for query in queries:
        # Monotonic, high-resolution timer for measuring elapsed time.
        start = time.perf_counter()
        results = rag.query(query, k=k)
        elapsed = time.perf_counter() - start
        print_results(name, query, results, elapsed)
    print("\n{} test complete!".format(name))
|