import json
import os
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from loguru import logger

from .document_processor import Document


class BaseVectorStore(ABC):
    """Common interface for the vector stores built by ``get_vector_store``."""

    @abstractmethod
    def add_documents(self, documents, embeddings):
        """Index ``documents``, optionally using precomputed ``embeddings``."""

    @abstractmethod
    def similarity_search(self, query_embedding, k=10):
        """Return up to ``k`` ``(Document, score)`` pairs for ``query_embedding``."""

    @abstractmethod
    def save(self, path):
        """Persist the store under ``path``."""

    @abstractmethod
    def load(self, path):
        """Restore the store from ``path``."""


class FAISSStore(BaseVectorStore):
    """In-process FAISS ``IndexFlatIP`` store over L2-normalised vectors.

    Stored and query vectors are both L2-normalised before reaching the
    inner-product index, so the returned scores are cosine similarities.
    Row ``i`` of the index corresponds to ``self.documents[i]``.
    """

    def __init__(self, embedding_model=None):
        # Imported lazily so faiss is only required when this store is used.
        import faiss

        self.faiss = faiss
        self.index = None
        self.documents: List[Document] = []
        self.embedding_model = embedding_model
        self.dimension = 0

    def add_documents(
        self, documents: List[Document], embeddings: Optional[Any] = None
    ) -> None:
        """Normalise and append ``documents`` (and their vectors) to the index.

        Args:
            documents: Iterable of ``Document``. It is materialised once so a
                generator is not exhausted before the documents are stored.
            embeddings: Optional 2-D array-like with one row per document.
                When omitted, ``self.embedding_model.embed_documents`` is used.

        Raises:
            ValueError: if no embeddings are given and no embedding model is
                configured, if the embeddings are not one row per document, or
                if their width differs from the existing index dimension.
        """
        documents = list(documents)
        if not documents:
            # Nothing to index (and an empty array has no shape[1]).
            return

        if embeddings is None:
            if not self.embedding_model:
                raise ValueError(
                    "embeddings must be provided when no embedding_model is configured"
                )
            embeddings = self.embedding_model.embed_documents(
                [doc.page_content for doc in documents]
            )

        # Explicit C-ordered float32 copy: faiss.normalize_L2 normalises in
        # place, so the caller's array must not be handed over directly.
        embeddings_np = np.array(embeddings, dtype=np.float32, order="C")
        if embeddings_np.ndim != 2 or embeddings_np.shape[0] != len(documents):
            # A mismatch would silently desynchronise index rows and documents.
            raise ValueError(
                "Expected embeddings of shape ({}, dim), got {}".format(
                    len(documents), embeddings_np.shape
                )
            )

        dim = embeddings_np.shape[1]
        if self.index is None:
            self.dimension = dim
            self.index = self.faiss.IndexFlatIP(dim)
        elif dim != self.index.d:
            raise ValueError(
                "Embedding dimension {} does not match index dimension {}".format(
                    dim, self.index.d
                )
            )

        self.faiss.normalize_L2(embeddings_np)
        self.index.add(embeddings_np)
        self.documents.extend(documents)
        logger.info(
            "Added {} documents to FAISS index, total: {}",
            len(documents),
            len(self.documents),
        )

    def similarity_search(
        self, query_embedding, k: int = 10
    ) -> List[Tuple[Document, float]]:
        """Return up to ``k`` ``(Document, cosine_similarity)`` pairs, best first.

        Returns an empty list when nothing is indexed or ``k`` is not positive.
        """
        if self.index is None:
            return []
        n_docs = len(self.documents)
        k = min(k, n_docs)
        if k <= 0:
            # faiss needs k >= 1; an empty store or non-positive k has no hits.
            return []

        query_np = np.array([query_embedding], dtype=np.float32, order="C")
        self.faiss.normalize_L2(query_np)
        scores, indices = self.index.search(query_np, k)
        # faiss pads missing results with -1; the upper bound guards against
        # an index holding more vectors than there are stored documents.
        return [
            (self.documents[idx], float(score))
            for score, idx in zip(scores[0], indices[0])
            if 0 <= idx < n_docs
        ]

    def save(self, path: str) -> None:
        """Write ``faiss.index`` and ``documents.pkl`` into directory ``path``.

        Raises:
            ValueError: if no documents have been added yet (no index exists).
        """
        if self.index is None:
            raise ValueError("Cannot save FAISS store: no documents have been added")
        os.makedirs(path, exist_ok=True)
        self.faiss.write_index(self.index, os.path.join(path, "faiss.index"))
        with open(os.path.join(path, "documents.pkl"), "wb") as f:
            pickle.dump(self.documents, f)
        logger.info("Saved FAISS index to {}", path)

    def load(self, path: str) -> None:
        """Restore the index and document list previously written by ``save``."""
        self.index = self.faiss.read_index(os.path.join(path, "faiss.index"))
        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only load directories produced by a trusted ``save``.
        with open(os.path.join(path, "documents.pkl"), "rb") as f:
            self.documents = pickle.load(f)
        self.dimension = self.index.d
        if self.index.ntotal != len(self.documents):
            logger.warning(
                "FAISS index holds {} vectors but {} documents were loaded",
                self.index.ntotal,
                len(self.documents),
            )
        logger.info(
            "Loaded FAISS index from {}, documents: {}", path, len(self.documents)
        )


class ChromaStore(BaseVectorStore):
    """Vector store backed by a chromadb collection.

    Collections created here use cosine distance, so the ``1 - distance``
    scores from ``similarity_search`` are cosine similarities, on the same
    scale as ``FAISSStore``.
    """

    def __init__(self, embedding_model=None, collection_name="bidi_collection"):
        # Imported lazily so chromadb is only required when this store is used.
        import chromadb

        self.chroma = chromadb
        self.client = None
        self.collection = None
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        # Only len(self.documents) is relied on (to number new ids); after
        # ``load`` the entries are empty placeholders.
        self.documents: List[Document] = []

    def _ensure_client(self) -> None:
        """Create the chromadb client on first use."""
        if self.client is None:
            # NOTE(review): depending on the chromadb version, Client() may be
            # an in-memory client; confirm data really persists across runs.
            self.client = self.chroma.Client()

    def _collection_kwargs(self) -> Dict[str, Any]:
        """Return the name/embedding-function kwargs shared by create and get."""
        kwargs: Dict[str, Any] = {"name": self.collection_name}
        if self.embedding_model:
            kwargs["embedding_function"] = self._embedding_wrapper()
        return kwargs

    def add_documents(
        self, documents: List[Document], embeddings: Optional[Any] = None
    ) -> None:
        """Add ``documents`` to the collection, creating it on first use.

        Args:
            documents: Iterable of ``Document``. It is materialised once so a
                generator is not exhausted before the documents are stored.
            embeddings: Optional precomputed vectors, one per document. When
                omitted and an embedding model is set, the model embeds the
                texts; otherwise chroma's own embedding function applies.
        """
        documents = list(documents)
        if not documents:
            # chroma rejects an add() with no ids; nothing to do anyway.
            return

        self._ensure_client()
        texts = [doc.page_content for doc in documents]
        if embeddings is None and self.embedding_model:
            embeddings = self.embedding_model.embed_documents(texts)

        if self.collection is None:
            # Cosine space makes ``1 - distance`` a cosine similarity.
            self.collection = self.client.create_collection(
                metadata={"hnsw:space": "cosine"},
                **self._collection_kwargs(),
            )

        # Ids continue after the documents already tracked by this store.
        offset = len(self.documents)
        ids = ["doc_{}".format(offset + i) for i in range(len(documents))]
        self.collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=[doc.metadata for doc in documents],
        )
        self.documents.extend(documents)
        logger.info(
            "Added {} documents to Chroma, total: {}",
            len(documents),
            len(self.documents),
        )

    def similarity_search(
        self, query_embedding, k: int = 10
    ) -> List[Tuple[Document, float]]:
        """Return up to ``k`` ``(Document, 1 - distance)`` pairs, best first.

        Returns an empty list when no collection exists or ``k`` is not positive.
        """
        if self.collection is None:
            return []
        k = min(k, len(self.documents))
        if k <= 0:
            # chroma requires n_results >= 1.
            return []

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            include=["documents", "metadatas", "distances"],
        )
        hits = []
        for doc_text, metadata, distance in zip(
            results["documents"][0], results["metadatas"][0], results["distances"][0]
        ):
            # chroma reports None for entries stored without metadata.
            doc = Document(page_content=doc_text, metadata=metadata or {})
            # Cosine similarity for collections created by this class; a
            # collection created elsewhere with another space uses that scale.
            hits.append((doc, 1.0 - distance))
        return hits

    def _embedding_wrapper(self):
        """Adapt ``self.embedding_model`` to chroma's ``__call__(input)`` protocol."""

        class EmbeddingFunc(object):
            def __init__(self, model):
                self.model = model

            def __call__(self, input):
                return self.model.embed_documents(input)

        return EmbeddingFunc(self.embedding_model)

    def save(self, path: str) -> None:
        """No-op: only logs, since the chroma client manages its own storage."""
        logger.info(
            "Chroma auto-persists data, no explicit save needed for path: {}", path
        )

    def load(self, path: str) -> None:
        """Attach to the existing collection named ``self.collection_name``.

        ``path`` is not used; the collection is looked up on the client. A
        missing collection is logged as a warning rather than raised.
        """
        self._ensure_client()
        try:
            collection = self.client.get_collection(**self._collection_kwargs())
            # count() avoids fetching every stored document just to size the list.
            count = collection.count()
        except Exception as e:
            # Deliberate best-effort: callers may load before anything exists.
            logger.warning("Chroma collection not found: {}", e)
            return
        self.collection = collection
        # Placeholders keep len(self.documents) in sync with the collection so
        # newly added ids do not collide with stored ones.
        self.documents = [
            Document(page_content="", metadata={}) for _ in range(count)
        ]
        logger.info("Loaded Chroma collection: {}", self.collection_name)


def get_vector_store(store_type="faiss", **kwargs):
    """Instantiate a vector store by name (``"faiss"`` or ``"chroma"``).

    The name is matched case-insensitively, and extra keyword arguments are
    forwarded to the store's constructor.

    Raises:
        ValueError: if ``store_type`` is not a known store name.
    """
    store_map = {
        "faiss": FAISSStore,
        "chroma": ChromaStore,
    }
    cls = store_map.get(store_type.lower())
    if not cls:
        raise ValueError("Unknown vector store type: {}".format(store_type))
    return cls(**kwargs)