| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- from abc import ABC, abstractmethod
- from typing import List
- import numpy as np
- from loguru import logger
class BaseEmbedding(ABC):
    """Common interface implemented by every embedding backend in this module.

    Concrete backends are expected to turn text into vectors of floats; the
    backends defined alongside this class also expose a ``dimension``
    attribute holding the vector length they produce.
    """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of *texts*, in input order."""

    @abstractmethod
    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for the single string *text*."""
class SentenceTransformerEmbedding(BaseEmbedding):
    """Embedding backend that runs a local ``sentence_transformers`` model.

    Args:
        model_name: Model id or path passed to ``SentenceTransformer``.
        device: Device string forwarded to ``SentenceTransformer`` (e.g. "cpu").
        batch_size: Number of texts per forward pass inside ``encode``.
        normalize_embeddings: When True, ``encode`` L2-normalizes every
            vector. Defaults to False so existing output is unchanged.
    """

    def __init__(self, model_name="BAAI/bge-large-zh-v1.5", device="cpu",
                 batch_size=32, normalize_embeddings=False):
        # Imported lazily so the dependency is only needed when this backend is used.
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_name, device=device)
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.dimension = self.model.get_sentence_embedding_dimension()
        logger.info("Loaded SentenceTransformer: {}, dim={}".format(model_name, self.dimension))

    def _encode(self, texts):
        """Encode *texts* with this instance's settings and return nested Python lists."""
        embeddings = self.model.encode(
            texts,
            batch_size=self.batch_size,
            # Passed explicitly: encode's default may print a progress bar
            # depending on logging configuration, which is noisy per query.
            show_progress_bar=False,
            normalize_embeddings=self.normalize_embeddings,
        )
        return embeddings.tolist()

    def embed_documents(self, texts):
        """Return one embedding (list of floats) per text, in input order."""
        return self._encode(texts)

    def embed_query(self, text):
        """Return the embedding for a single query string."""
        # Same settings as embed_documents so queries and documents are comparable.
        return self._encode([text])[0]
class OpenAIEmbedding(BaseEmbedding):
    """Embedding backend for the OpenAI (or an OpenAI-compatible) embeddings API.

    Args:
        model_name: Model id sent to ``client.embeddings.create``.
        api_key: Forwarded to ``OpenAI(...)``; when None the SDK's own default applies.
        base_url: Optional endpoint override forwarded to ``OpenAI(...)``.
        dimensions: Optional output size sent as the API's ``dimensions``
            parameter. When None the parameter is omitted and the model's
            native size is reported in ``self.dimension``.
        batch_size: Maximum number of texts sent in one request.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    # Native output sizes of known models; unknown models fall back to 1536.
    _NATIVE_DIMENSIONS = {"text-embedding-3-large": 3072, "text-embedding-3-small": 1536}

    def __init__(self, model_name="text-embedding-3-large", api_key=None, base_url=None,
                 dimensions=None, batch_size=100):
        from openai import OpenAI
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.model_name = model_name
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.batch_size = batch_size
        self._requested_dimensions = dimensions
        if dimensions is not None:
            self.dimension = dimensions
        else:
            self.dimension = self._NATIVE_DIMENSIONS.get(model_name, 1536)
        logger.info("Loaded OpenAI Embedding: {}, dim={}".format(model_name, self.dimension))

    def _embed_batch(self, batch):
        """Send one embeddings request for *batch*; return vectors in input order."""
        request = {"model": self.model_name, "input": batch}
        # Only send `dimensions` when asked for: not every model accepts it.
        if self._requested_dimensions is not None:
            request["dimensions"] = self._requested_dimensions
        response = self.client.embeddings.create(**request)
        # Each returned item carries the position of its input; order by it
        # instead of relying on the response list order.
        return [item.embedding for item in sorted(response.data, key=lambda item: item.index)]

    def embed_documents(self, texts):
        """Embed *texts* in chunks of ``batch_size``; return one vector per text."""
        embeddings = []
        for start in range(0, len(texts), self.batch_size):
            embeddings.extend(self._embed_batch(texts[start:start + self.batch_size]))
        return embeddings

    def embed_query(self, text):
        """Return the embedding for a single query string."""
        return self._embed_batch([text])[0]
class DashScopeEmbedding(BaseEmbedding):
    """Embedding backend for DashScope ``TextEmbedding``.

    Args:
        model_name: DashScope embedding model id.
        api_key: Key used for requests. It is passed explicitly on every
            call and also stored as ``dashscope.api_key`` when provided. When
            None, the module-global key is left untouched and the SDK's own
            default lookup applies.
        batch_size: Maximum number of texts sent in one request.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    # Default output sizes of known models; unknown models fall back to 1536.
    _NATIVE_DIMENSIONS = {
        "text-embedding-v1": 1536,
        "text-embedding-v2": 1536,
        "text-embedding-v3": 1024,
    }

    def __init__(self, model_name="text-embedding-v2", api_key=None, batch_size=25):
        import dashscope
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.model_name = model_name
        self.api_key = api_key
        self.batch_size = batch_size
        if api_key is not None:
            # Kept for callers relying on the global setting, but a None key no
            # longer wipes out a key configured elsewhere in the process.
            dashscope.api_key = api_key
        self.dimension = self._NATIVE_DIMENSIONS.get(model_name, 1536)
        logger.info("Loaded DashScope Embedding: {}".format(model_name))

    def embed_documents(self, texts):
        """Embed *texts* in chunks of ``batch_size``; return one vector per text.

        Raises:
            RuntimeError: If DashScope returns a non-200 status for any chunk.
        """
        from dashscope import TextEmbedding
        embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            # Per-instance key so several instances with different keys don't
            # race on the module-global setting.
            response = TextEmbedding.call(model=self.model_name, input=batch, api_key=self.api_key)
            if response.status_code != 200:
                raise RuntimeError("DashScope embedding failed: {}".format(response.message))
            items = response.output["embeddings"]
            # Order by the input position reported in `text_index`; fall back
            # to response order if the field is absent.
            ordered = sorted(enumerate(items), key=lambda pair: pair[1].get("text_index", pair[0]))
            embeddings.extend(item["embedding"] for _, item in ordered)
        return embeddings

    def embed_query(self, text):
        """Return the embedding for a single query string."""
        return self.embed_documents([text])[0]
class ZhipuEmbedding(BaseEmbedding):
    """Embedding backend for the Zhipu AI embeddings API.

    Args:
        model_name: Zhipu embedding model id.
        api_key: Forwarded to ``ZhipuAI(...)``.
        batch_size: Number of texts sent per request. Defaults to 1 (one
            request per text, the historical behavior); larger values send
            list inputs and are limited by the server.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    # Output sizes of known models; unknown models fall back to 2048.
    _NATIVE_DIMENSIONS = {"embedding-2": 1024, "embedding-3": 2048}

    def __init__(self, model_name="embedding-3", api_key=None, batch_size=1):
        from zhipuai import ZhipuAI
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
        self.model_name = model_name
        self.client = ZhipuAI(api_key=api_key)
        self.batch_size = batch_size
        self.dimension = self._NATIVE_DIMENSIONS.get(model_name, 2048)
        logger.info("Loaded Zhipu Embedding: {}".format(model_name))

    def embed_documents(self, texts):
        """Embed *texts* in chunks of ``batch_size``; return one vector per text."""
        embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            response = self.client.embeddings.create(model=self.model_name, input=batch)
            # NOTE(review): assumes response.data preserves input order when
            # batch_size > 1 — confirm against the Zhipu API.
            embeddings.extend(item.embedding for item in response.data)
        return embeddings

    def embed_query(self, text):
        """Return the embedding for a single query string."""
        return self.embed_documents([text])[0]
def get_embedding(model_type="sentence_transformer", **kwargs):
    """Instantiate the embedding backend registered under *model_type*.

    Args:
        model_type: One of "sentence_transformer", "openai", "dashscope", "zhipu".
        **kwargs: Forwarded unchanged to the selected backend's constructor.

    Returns:
        A new instance of the selected backend class.

    Raises:
        ValueError: If *model_type* is not a registered name.
    """
    registry = {
        "sentence_transformer": SentenceTransformerEmbedding,
        "openai": OpenAIEmbedding,
        "dashscope": DashScopeEmbedding,
        "zhipu": ZhipuEmbedding,
    }
    if model_type not in registry:
        raise ValueError("Unknown embedding model type: {}".format(model_type))
    backend_cls = registry[model_type]
    return backend_cls(**kwargs)
|