"""Text-embedding backends that share the ``BaseEmbedding`` interface.

Each backend imports its third-party SDK inside ``__init__`` (or inside the
method that needs it), so only the SDK of the backend actually instantiated
has to be installed.
"""

from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np  # noqa: F401  (not referenced in this part of the file)
from loguru import logger


def _iter_batches(items, batch_size):
    """Yield consecutive slices of *items* with at most *batch_size* entries.

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size)
        )
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


class BaseEmbedding(ABC):
    """Interface every embedding backend implements."""

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of *texts*."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector of a single *text*."""


class SentenceTransformerEmbedding(BaseEmbedding):
    """Embeddings computed locally by a ``sentence_transformers`` model."""

    DEFAULT_BATCH_SIZE = 32

    def __init__(
        self,
        model_name: str = "BAAI/bge-large-zh-v1.5",
        device: str = "cpu",
        batch_size: int = DEFAULT_BATCH_SIZE,
    ):
        """Load *model_name* on *device*.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
            device: Device string handed to ``SentenceTransformer``.
            batch_size: ``batch_size`` forwarded to ``encode`` by
                ``embed_documents``.
        """
        # Imported lazily so the dependency is only needed for this backend.
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name, device=device)
        self.batch_size = batch_size
        self.dimension = self.model.get_sentence_embedding_dimension()
        logger.info(
            "Loaded SentenceTransformer: {}, dim={}", model_name, self.dimension
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return the result converted with ``tolist()``."""
        vectors = self.model.encode(
            texts, batch_size=self.batch_size, show_progress_bar=False
        )
        return vectors.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Encode a single *text* and return its vector as a list."""
        # encode() receives a one-element list; [0] picks that single row.
        vector = self.model.encode([text])[0]
        return vector.tolist()


class OpenAIEmbedding(BaseEmbedding):
    """Embeddings fetched through the ``openai`` client's embeddings endpoint."""

    DEFAULT_BATCH_SIZE = 100
    # Sizes reported for the models this class knows about.
    _KNOWN_DIMENSIONS = {
        "text-embedding-3-large": 3072,
        "text-embedding-3-small": 1536,
    }
    # Value reported for any model name missing from the table above.
    _FALLBACK_DIMENSION = 1536

    def __init__(
        self,
        model_name: str = "text-embedding-3-large",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ):
        """Create the client.

        Args:
            model_name: Model name sent with every request.
            api_key: Forwarded unchanged to ``OpenAI``.
            base_url: Forwarded unchanged to ``OpenAI``.
            batch_size: Maximum number of texts sent per request by
                ``embed_documents``.
        """
        from openai import OpenAI

        self.model_name = model_name
        self.batch_size = batch_size
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        # Looked up from a static table, not queried from the service.
        self.dimension = self._KNOWN_DIMENSIONS.get(
            model_name, self._FALLBACK_DIMENSION
        )
        logger.info(
            "Loaded OpenAI Embedding: {}, dim={}", model_name, self.dimension
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts*, issuing one request per batch.

        Vectors are collected in the order the responses list them.
        """
        embeddings = []
        for batch in _iter_batches(texts, self.batch_size):
            response = self.client.embeddings.create(
                model=self.model_name, input=batch
            )
            embeddings.extend(item.embedding for item in response.data)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single *text* with one request."""
        response = self.client.embeddings.create(
            model=self.model_name, input=[text]
        )
        return response.data[0].embedding


class DashScopeEmbedding(BaseEmbedding):
    """Embeddings fetched through ``dashscope.TextEmbedding``."""

    DEFAULT_BATCH_SIZE = 25

    def __init__(
        self,
        model_name: str = "text-embedding-v2",
        api_key: Optional[str] = None,
        batch_size: int = DEFAULT_BATCH_SIZE,
    ):
        """Configure the DashScope SDK.

        Args:
            model_name: Model name sent with every request.
            api_key: When given, stored on the ``dashscope`` module, which
                makes it process-wide. When ``None`` the module is left
                untouched.
            batch_size: Maximum number of texts sent per request by
                ``embed_documents``.
        """
        import dashscope

        self.model_name = model_name
        self.batch_size = batch_size
        # Only override the module-level key when one was supplied; assigning
        # None would discard a key that was already configured on the SDK.
        if api_key is not None:
            dashscope.api_key = api_key
        # Hard-coded; not derived from model_name or queried from the service.
        self.dimension = 1536
        logger.info("Loaded DashScope Embedding: {}", model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts*, issuing one request per batch.

        Raises:
            RuntimeError: if a response's ``status_code`` is not 200.
        """
        from dashscope import TextEmbedding

        embeddings = []
        for batch in _iter_batches(texts, self.batch_size):
            response = TextEmbedding.call(model=self.model_name, input=batch)
            if response.status_code != 200:
                raise RuntimeError(
                    "DashScope embedding failed: {}".format(response.message)
                )
            # Align vectors with their inputs via text_index. The sort is
            # stable, so entries lacking the field keep the response order.
            ordered = sorted(
                response.output["embeddings"],
                key=lambda item: item.get("text_index", 0),
            )
            embeddings.extend(item["embedding"] for item in ordered)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single *text* by delegating to ``embed_documents``."""
        return self.embed_documents([text])[0]


class ZhipuEmbedding(BaseEmbedding):
    """Embeddings fetched through the ``zhipuai`` client."""

    def __init__(self, model_name: str = "embedding-3", api_key: Optional[str] = None):
        """Create the client.

        Args:
            model_name: Model name sent with every request.
            api_key: Forwarded unchanged to ``ZhipuAI``.
        """
        from zhipuai import ZhipuAI

        self.model_name = model_name
        self.client = ZhipuAI(api_key=api_key)
        # Hard-coded; not derived from model_name or queried from the service.
        self.dimension = 2048
        logger.info("Loaded Zhipu Embedding: {}", model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with one request per text, preserving input order."""
        embeddings = []
        for text in texts:
            response = self.client.embeddings.create(
                model=self.model_name, input=[text]
            )
            embeddings.append(response.data[0].embedding)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a single *text* by delegating to ``embed_documents``."""
        return self.embed_documents([text])[0]


# Maps the ``model_type`` accepted by get_embedding to the backend class.
_EMBEDDING_CLASSES = {
    "sentence_transformer": SentenceTransformerEmbedding,
    "openai": OpenAIEmbedding,
    "dashscope": DashScopeEmbedding,
    "zhipu": ZhipuEmbedding,
}


def get_embedding(model_type: str = "sentence_transformer", **kwargs) -> BaseEmbedding:
    """Instantiate the embedding backend registered under *model_type*.

    Args:
        model_type: One of ``"sentence_transformer"``, ``"openai"``,
            ``"dashscope"`` or ``"zhipu"``.
        **kwargs: Passed unchanged to the backend's constructor.

    Returns:
        The constructed backend instance.

    Raises:
        ValueError: if *model_type* is not a registered backend name.
    """
    backend_cls = _EMBEDDING_CLASSES.get(model_type)
    if backend_cls is None:
        raise ValueError("Unknown embedding model type: {}".format(model_type))
    return backend_cls(**kwargs)