| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- import os
- import json
- import pickle
- from abc import ABC, abstractmethod
- from typing import List, Dict, Tuple, Any, Optional
- import numpy as np
- from loguru import logger
- from .document_processor import Document
class BaseVectorStore(ABC):
    """Abstract interface implemented by every vector store backend.

    A store keeps documents together with their embedding vectors and
    answers nearest-neighbour queries expressed as an embedding vector.
    """

    @abstractmethod
    def add_documents(
        self,
        documents: List["Document"],
        embeddings: Optional[List[List[float]]] = None,
    ) -> None:
        """Add ``documents`` and their vectors to the store.

        Args:
            documents: Documents to index.
            embeddings: One vector per document, in the same order.  The
                default of ``None`` matches the concrete stores in this
                module, which accept the argument being omitted.
        """

    @abstractmethod
    def similarity_search(
        self, query_embedding: List[float], k: int = 10
    ) -> List[Tuple["Document", float]]:
        """Return up to ``k`` ``(document, score)`` pairs for the query vector.

        Args:
            query_embedding: Embedding vector of the query.
            k: Maximum number of results to return.
        """

    @abstractmethod
    def save(self, path: str) -> None:
        """Persist the store under ``path``."""

    @abstractmethod
    def load(self, path: str) -> None:
        """Restore the store from ``path``."""
class FAISSStore(BaseVectorStore):
    """Vector store backed by a flat FAISS inner-product index.

    Vectors are L2-normalised before they are added or queried, so the
    inner-product scores returned by the index are cosine similarities.
    ``documents`` is kept parallel to the index: row ``i`` of the index
    belongs to ``documents[i]``.
    """

    def __init__(self, embedding_model=None):
        """Create an empty store.

        Args:
            embedding_model: Optional object exposing
                ``embed_documents(texts)``; used by ``add_documents`` when
                no embeddings are passed in.
        """
        # Imported here rather than at module level, so faiss is only
        # needed when this backend is actually instantiated.
        import faiss

        self.faiss = faiss
        self.index = None
        self.documents = []
        self.embedding_model = embedding_model
        self.dimension = 0

    def add_documents(self, documents, embeddings=None):
        """Index ``documents`` together with their embedding vectors.

        Args:
            documents: Documents to add; each must expose ``page_content``
                when embeddings have to be computed.
            embeddings: One vector per document.  When ``None`` they are
                computed with ``self.embedding_model``.

        Raises:
            ValueError: If no embeddings can be obtained, if the number of
                vectors differs from the number of documents, or if the
                vector dimension does not match the existing index.
        """
        documents = list(documents)
        if not documents:
            if embeddings is not None and len(embeddings) > 0:
                raise ValueError(
                    "Got {} embeddings but no documents".format(len(embeddings))
                )
            # Nothing to add; avoids building a degenerate 0-row array.
            return

        if embeddings is None:
            if not self.embedding_model:
                raise ValueError(
                    "No embeddings were given and the store has no embedding_model"
                )
            texts = [doc.page_content for doc in documents]
            embeddings = self.embedding_model.embed_documents(texts)

        # np.array always copies, so the in-place normalisation below never
        # touches the caller's data; order="C" gives the contiguous layout
        # that is passed on to faiss.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.ndim != 2:
            raise ValueError(
                "embeddings must be a 2-D array, got shape {}".format(vectors.shape)
            )
        if vectors.shape[0] != len(documents):
            # A mismatch would silently break the index row <-> document
            # correspondence that similarity_search relies on.
            raise ValueError(
                "Got {} embeddings for {} documents".format(
                    vectors.shape[0], len(documents)
                )
            )

        dimension = vectors.shape[1]
        if self.index is None:
            self.index = self.faiss.IndexFlatIP(dimension)
        elif dimension != self.index.d:
            raise ValueError(
                "Embedding dimension {} does not match index dimension {}".format(
                    dimension, self.index.d
                )
            )
        self.dimension = dimension

        self.faiss.normalize_L2(vectors)
        self.index.add(vectors)
        self.documents.extend(documents)
        logger.info("Added {} documents to FAISS index, total: {}".format(len(documents), len(self.documents)))

    def similarity_search(self, query_embedding, k=10):
        """Return the ``k`` stored documents most similar to the query.

        Args:
            query_embedding: Query vector, either 1-D or shaped ``(1, d)``.
            k: Maximum number of results; capped at the number of stored
                documents.

        Returns:
            A list of ``(document, score)`` tuples in the order returned by
            the index; empty when the store is empty or ``k`` is not
            positive.

        Raises:
            ValueError: If the query dimension differs from the index
                dimension.
        """
        if self.index is None:
            return []

        k = min(k, len(self.documents))
        if k <= 0:
            # Searching with k == 0 (empty store or caller passed 0) has
            # no meaningful result, so skip the index call altogether.
            return []

        # Copy (normalize_L2 works in place) and force a single row so that
        # both 1-D and (1, d) inputs are accepted.
        query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)
        if query.shape[1] != self.index.d:
            raise ValueError(
                "Query dimension {} does not match index dimension {}".format(
                    query.shape[1], self.index.d
                )
            )
        self.faiss.normalize_L2(query)
        scores, indices = self.index.search(query, k)

        # The index pads missing neighbours with -1; drop those slots.
        return [
            (self.documents[idx], float(score))
            for score, idx in zip(scores[0], indices[0])
            if idx != -1
        ]

    def save(self, path):
        """Write the index and the document list into directory ``path``.

        Creates ``faiss.index`` and ``documents.pkl`` inside ``path``.

        Raises:
            RuntimeError: If there is no index to save yet.
        """
        if self.index is None:
            # Fail before creating the directory, with a clear message.
            raise RuntimeError(
                "Cannot save FAISS store: no index has been built or loaded"
            )
        os.makedirs(path, exist_ok=True)
        self.faiss.write_index(self.index, os.path.join(path, "faiss.index"))
        with open(os.path.join(path, "documents.pkl"), "wb") as f:
            pickle.dump(self.documents, f)
        logger.info("Saved FAISS index to {}".format(path))

    def load(self, path):
        """Replace the store's contents with the data saved under ``path``.

        SECURITY: ``documents.pkl`` is read with ``pickle.load``, which can
        execute arbitrary code contained in the file.  Only load
        directories that were written by a trusted ``save`` call.
        """
        index = self.faiss.read_index(os.path.join(path, "faiss.index"))
        with open(os.path.join(path, "documents.pkl"), "rb") as f:
            documents = pickle.load(f)

        # Assign only after both files were read, so a failure part-way
        # through leaves the previous state untouched.
        self.index = index
        self.documents = documents
        self.dimension = self.index.d
        if self.index.ntotal != len(self.documents):
            logger.warning(
                "FAISS index holds {} vectors but {} documents were loaded".format(
                    self.index.ntotal, len(self.documents)
                )
            )
        logger.info("Loaded FAISS index from {}, documents: {}".format(path, len(self.documents)))
class ChromaStore(BaseVectorStore):
    """Vector store backed by a Chroma collection.

    Without a ``persist_directory`` the store uses ``chromadb.Client()``,
    which keeps data in memory only.  With a directory (given to the
    constructor or to ``load``) a ``chromadb.PersistentClient`` is used so
    the collection survives across processes.
    """

    def __init__(
        self,
        embedding_model=None,
        collection_name="bidi_collection",
        persist_directory=None,
    ):
        """Create a store; no client or collection is opened yet.

        Args:
            embedding_model: Optional object exposing
                ``embed_documents(texts)``.
            collection_name: Name of the Chroma collection to use.
            persist_directory: Optional directory for on-disk persistence.
        """
        # Imported here rather than at module level, so chromadb is only
        # needed when this backend is actually instantiated.
        import chromadb

        self.chroma = chromadb
        self.client = None
        self.collection = None
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.persist_directory = persist_directory
        self.documents = []

    def _ensure_client(self, path=None):
        """Create the Chroma client on first use.

        Args:
            path: Directory to persist to; falls back to
                ``self.persist_directory``.  Ignored when a client already
                exists.
        """
        if self.client is not None:
            return

        directory = path or self.persist_directory
        # Looked up defensively in case the installed chromadb release
        # does not provide PersistentClient.
        persistent_factory = getattr(self.chroma, "PersistentClient", None)
        if directory and persistent_factory is not None:
            self.client = persistent_factory(path=str(directory))
            self.persist_directory = str(directory)
            return

        if directory:
            logger.warning(
                "chromadb.PersistentClient is unavailable; using an in-memory client instead of {}".format(directory)
            )
        self.client = self.chroma.Client()

    def _sync_documents(self):
        """Rebuild ``self.documents`` from what the collection stores."""
        stored = self.collection.get(include=["documents", "metadatas"])
        count = len(stored["ids"])
        texts = stored.get("documents") or [None] * count
        metadatas = stored.get("metadatas") or [None] * count
        self.documents = [
            Document(page_content=text or "", metadata=metadata or {})
            for text, metadata in zip(texts, metadatas)
        ]

    def add_documents(self, documents, embeddings=None):
        """Add ``documents`` to the collection.

        Args:
            documents: Documents exposing ``page_content`` and ``metadata``.
            embeddings: One vector per document.  When ``None`` and an
                ``embedding_model`` is set, vectors are computed with it;
                otherwise ``None`` is handed to Chroma unchanged.
        """
        documents = list(documents)
        if not documents:
            # Avoids creating a collection and sending an empty batch.
            return

        self._ensure_client()

        texts = [doc.page_content for doc in documents]
        if embeddings is None and self.embedding_model:
            embeddings = self.embedding_model.embed_documents(texts)

        if self.collection is None:
            # similarity_search reports ``1 - distance`` as the score,
            # which is only a similarity when distance is cosine distance.
            options = {
                "name": self.collection_name,
                "metadata": {"hnsw:space": "cosine"},
            }
            if self.embedding_model:
                options["embedding_function"] = self._embedding_wrapper()
            # get_or_create so that an already existing collection (for
            # example in a persistent directory) is reused, not an error.
            self.collection = self.client.get_or_create_collection(**options)
            # A reused collection may already hold entries; pick them up
            # so the id numbering and the k cap below stay correct.
            self._sync_documents()

        offset = len(self.documents)
        ids = ["doc_{}".format(offset + i) for i in range(len(documents))]
        # NOTE(review): some chromadb versions reject empty metadata dicts,
        # so those are sent as None instead - confirm against the pinned version.
        metadatas = [doc.metadata or None for doc in documents]
        self.collection.add(
            ids=ids,
            documents=texts,
            embeddings=embeddings,
            metadatas=metadatas,
        )
        self.documents.extend(documents)
        logger.info("Added {} documents to Chroma, total: {}".format(len(documents), len(self.documents)))

    def similarity_search(self, query_embedding, k=10):
        """Return up to ``k`` ``(document, score)`` pairs for the query.

        The score is ``1.0 - distance`` as reported by Chroma.  Returns an
        empty list when no collection is open, the store is empty, or
        ``k`` is not positive.
        """
        if self.collection is None:
            return []

        k = min(k, len(self.documents))
        if k <= 0:
            # Do not query the collection for zero results.
            return []

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=k,
            include=["documents", "metadatas", "distances"],
        )
        # Each result field holds one list per query; only one was sent.
        hits = zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        return [
            # Entries stored without metadata come back as None.
            (Document(page_content=text, metadata=metadata or {}), 1.0 - distance)
            for text, metadata, distance in hits
        ]

    def _embedding_wrapper(self):
        """Wrap ``self.embedding_model`` in a callable for Chroma."""

        class EmbeddingFunc(object):
            def __init__(self, model):
                self.model = model

            # NOTE(review): the parameter is kept as ``input`` (shadowing
            # the builtin) on the assumption that Chroma checks the
            # __call__ signature - confirm before renaming.
            def __call__(self, input):
                return self.model.embed_documents(input)

        return EmbeddingFunc(self.embedding_model)

    def save(self, path):
        """Report where the data lives; this method writes nothing itself.

        A persistent store is written by Chroma as data is added.  An
        in-memory store cannot be saved by this method, which is logged
        as a warning rather than reported as success.
        """
        if not self.persist_directory:
            logger.warning(
                "Chroma store is in-memory (no persist_directory); nothing was saved to: {}".format(path)
            )
            return

        if path and os.path.abspath(str(path)) != os.path.abspath(str(self.persist_directory)):
            logger.warning(
                "Chroma data is persisted in {}, not in the requested path: {}".format(
                    self.persist_directory, path
                )
            )
            return

        logger.info("Chroma auto-persists data, no explicit save needed for path: {}".format(path))

    def load(self, path):
        """Open the collection, using ``path`` as the persistence directory.

        ``path`` only takes effect when no client exists yet.  A missing
        collection is logged as a warning and leaves the store unchanged;
        errors while reading an existing collection propagate.
        """
        self._ensure_client(path)
        try:
            collection = self.client.get_collection(name=self.collection_name)
        except Exception as e:
            # The exception type raised for a missing collection is not
            # pinned down here, hence the broad catch.
            logger.warning("Chroma collection not found: {}".format(e))
            return

        self.collection = collection
        # Restore real contents instead of blank placeholders.
        self._sync_documents()
        logger.info("Loaded Chroma collection: {}".format(self.collection_name))
def get_vector_store(store_type: str = "faiss", **kwargs: Any) -> "BaseVectorStore":
    """Instantiate a vector store backend by name.

    Args:
        store_type: Backend name, ``"faiss"`` or ``"chroma"``.  Matching
            ignores case and surrounding whitespace.
        **kwargs: Passed unchanged to the backend's constructor.

    Returns:
        A new instance of the selected store class.

    Raises:
        ValueError: If ``store_type`` does not name a known backend
            (including ``None`` or other non-string values).
    """
    store_map = {
        "faiss": FAISSStore,
        "chroma": ChromaStore,
    }
    # str() keeps None / non-string input on the ValueError path below
    # instead of failing with AttributeError on ``.lower()``.
    key = str(store_type).strip().lower()
    cls = store_map.get(key)
    if cls is None:
        raise ValueError("Unknown vector store type: {}".format(store_type))
    return cls(**kwargs)
|