vector_stores.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import os
  2. import json
  3. import pickle
  4. from abc import ABC, abstractmethod
  5. from typing import List, Dict, Tuple, Any, Optional
  6. import numpy as np
  7. from loguru import logger
  8. from .document_processor import Document
  9. class BaseVectorStore(ABC):
  10. @abstractmethod
  11. def add_documents(self, documents, embeddings):
  12. pass
  13. @abstractmethod
  14. def similarity_search(self, query_embedding, k=10):
  15. pass
  16. @abstractmethod
  17. def save(self, path):
  18. pass
  19. @abstractmethod
  20. def load(self, path):
  21. pass
  22. class FAISSStore(BaseVectorStore):
  23. def __init__(self, embedding_model=None):
  24. import faiss
  25. self.faiss = faiss
  26. self.index = None
  27. self.documents = []
  28. self.embedding_model = embedding_model
  29. self.dimension = 0
  30. def add_documents(self, documents, embeddings=None):
  31. if embeddings is None and self.embedding_model:
  32. texts = [doc.page_content for doc in documents]
  33. embeddings = self.embedding_model.embed_documents(texts)
  34. embeddings_np = np.array(embeddings, dtype=np.float32)
  35. self.dimension = embeddings_np.shape[1]
  36. if self.index is None:
  37. self.index = self.faiss.IndexFlatIP(self.dimension)
  38. self.faiss.normalize_L2(embeddings_np)
  39. self.index.add(embeddings_np)
  40. self.documents.extend(documents)
  41. logger.info("Added {} documents to FAISS index, total: {}".format(len(documents), len(self.documents)))
  42. def similarity_search(self, query_embedding, k=10):
  43. if self.index is None:
  44. return []
  45. k = min(k, len(self.documents))
  46. query_np = np.array([query_embedding], dtype=np.float32)
  47. self.faiss.normalize_L2(query_np)
  48. scores, indices = self.index.search(query_np, k)
  49. results = []
  50. for score, idx in zip(scores[0], indices[0]):
  51. if idx != -1:
  52. results.append((self.documents[idx], float(score)))
  53. return results
  54. def save(self, path):
  55. os.makedirs(path, exist_ok=True)
  56. self.faiss.write_index(self.index, os.path.join(path, "faiss.index"))
  57. with open(os.path.join(path, "documents.pkl"), "wb") as f:
  58. pickle.dump(self.documents, f)
  59. logger.info("Saved FAISS index to {}".format(path))
  60. def load(self, path):
  61. self.index = self.faiss.read_index(os.path.join(path, "faiss.index"))
  62. with open(os.path.join(path, "documents.pkl"), "rb") as f:
  63. self.documents = pickle.load(f)
  64. self.dimension = self.index.d
  65. logger.info("Loaded FAISS index from {}, documents: {}".format(path, len(self.documents)))
  66. class ChromaStore(BaseVectorStore):
  67. def __init__(self, embedding_model=None, collection_name="bidi_collection"):
  68. import chromadb
  69. self.chroma = chromadb
  70. self.client = None
  71. self.collection = None
  72. self.collection_name = collection_name
  73. self.embedding_model = embedding_model
  74. self.documents = []
  75. def _ensure_client(self):
  76. if self.client is None:
  77. self.client = self.chroma.Client()
  78. def add_documents(self, documents, embeddings=None):
  79. self._ensure_client()
  80. if embeddings is None and self.embedding_model:
  81. texts = [doc.page_content for doc in documents]
  82. embeddings = self.embedding_model.embed_documents(texts)
  83. if self.collection is None:
  84. if self.embedding_model:
  85. self.collection = self.client.create_collection(
  86. name=self.collection_name,
  87. embedding_function=self._embedding_wrapper(),
  88. )
  89. else:
  90. self.collection = self.client.create_collection(name=self.collection_name)
  91. ids = ["doc_{}".format(i + len(self.documents)) for i in range(len(documents))]
  92. texts = [doc.page_content for doc in documents]
  93. metadatas = [doc.metadata for doc in documents]
  94. self.collection.add(
  95. ids=ids,
  96. documents=texts,
  97. embeddings=embeddings,
  98. metadatas=metadatas,
  99. )
  100. self.documents.extend(documents)
  101. logger.info("Added {} documents to Chroma, total: {}".format(len(documents), len(self.documents)))
  102. def similarity_search(self, query_embedding, k=10):
  103. if self.collection is None:
  104. return []
  105. k = min(k, len(self.documents))
  106. results = self.collection.query(
  107. query_embeddings=[query_embedding],
  108. n_results=k,
  109. include=["documents", "metadatas", "distances"],
  110. )
  111. docs_returned = []
  112. for doc_text, metadata, distance in zip(
  113. results["documents"][0], results["metadatas"][0], results["distances"][0]
  114. ):
  115. doc = Document(page_content=doc_text, metadata=metadata)
  116. score = 1.0 - distance
  117. docs_returned.append((doc, score))
  118. return docs_returned
  119. def _embedding_wrapper(self):
  120. class EmbeddingFunc(object):
  121. def __init__(self, model):
  122. self.model = model
  123. def __call__(self, input):
  124. return self.model.embed_documents(input)
  125. return EmbeddingFunc(self.embedding_model)
  126. def save(self, path):
  127. logger.info("Chroma auto-persists data, no explicit save needed for path: {}".format(path))
  128. def load(self, path):
  129. self._ensure_client()
  130. try:
  131. self.collection = self.client.get_collection(name=self.collection_name)
  132. self.documents = [
  133. Document(page_content="", metadata={})
  134. for _ in range(len(self.collection.get()["ids"]))
  135. ]
  136. logger.info("Loaded Chroma collection: {}".format(self.collection_name))
  137. except Exception as e:
  138. logger.warning("Chroma collection not found: {}".format(e))
  139. def get_vector_store(store_type="faiss", **kwargs):
  140. store_map = {
  141. "faiss": FAISSStore,
  142. "chroma": ChromaStore,
  143. }
  144. cls = store_map.get(store_type.lower())
  145. if not cls:
  146. raise ValueError("Unknown vector store type: {}".format(store_type))
  147. return cls(**kwargs)