embedding_models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from abc import ABC, abstractmethod
  2. from typing import List
  3. import numpy as np
  4. from loguru import logger
  class BaseEmbedding(ABC):
      """Interface shared by every embedding backend in this module.

      Concrete subclasses must implement both abstract methods; the base
      class itself cannot be instantiated.
      """

      # Size of the vectors the backend produces. Not assigned here: each
      # concrete backend in this module sets ``self.dimension`` in __init__.
      dimension: int

      @abstractmethod
      def embed_documents(self, texts: List[str]) -> List[List[float]]:
          """Return one embedding vector per entry of *texts*, in the same order."""

      @abstractmethod
      def embed_query(self, text: str) -> List[float]:
          """Return the embedding vector for a single query string."""
  class SentenceTransformerEmbedding(BaseEmbedding):
      """Local embeddings computed with a ``sentence_transformers`` model.

      Args:
          model_name: Model id or path passed to ``SentenceTransformer``.
          device: Device string passed to ``SentenceTransformer``.
          batch_size: Texts per forward pass, used by both ``embed_documents``
              and ``embed_query``. Defaults to 32, the previously hard-coded value.
          normalize_embeddings: When True, ``encode`` is asked to L2-normalise
              every vector. Off by default so output matches earlier versions.
          query_instruction: Text prepended to the query in ``embed_query`` only
              (documents are never prefixed). Empty by default. NOTE(review):
              BGE model cards suggest an instruction for short-query retrieval;
              confirm against the model actually in use.
      """

      def __init__(self, model_name="BAAI/bge-large-zh-v1.5", device="cpu",
                   batch_size=32, normalize_embeddings=False, query_instruction=""):
          # Imported lazily so sentence_transformers is only required for this backend.
          from sentence_transformers import SentenceTransformer
          self.model = SentenceTransformer(model_name, device=device)
          self.batch_size = batch_size
          self.normalize_embeddings = normalize_embeddings
          self.query_instruction = query_instruction
          self.dimension = self.model.get_sentence_embedding_dimension()
          logger.info("Loaded SentenceTransformer: {}, dim={}".format(model_name, self.dimension))

      def _encode(self, texts):
          """Encode *texts* with the shared settings and return nested Python lists."""
          options = {"batch_size": self.batch_size, "show_progress_bar": False}
          if self.normalize_embeddings:
              # Only passed when requested, so the default call is unchanged.
              options["normalize_embeddings"] = True
          return self.model.encode(texts, **options).tolist()

      def embed_documents(self, texts):
          """Return one embedding (list of floats) per input text, in order."""
          return self._encode(texts)

      def embed_query(self, text):
          """Embed a single query, prefixed with ``query_instruction`` if one is set."""
          if self.query_instruction:
              text = self.query_instruction + text
          # Same encode settings as documents, so query and document vectors match.
          return self._encode([text])[0]
  class OpenAIEmbedding(BaseEmbedding):
      """Embeddings from the OpenAI embeddings endpoint (or a compatible server).

      Args:
          model_name: Model id sent with every request.
          api_key: Passed straight to ``OpenAI``.
          base_url: Passed straight to ``OpenAI``; set it to target a compatible endpoint.
          dimensions: Optional output size, sent as the ``dimensions`` request
              parameter (OpenAI documents it for ``text-embedding-3`` and later
              models only). When given it is also reported as ``self.dimension``.
          batch_size: Maximum number of texts per request in ``embed_documents``.
              Defaults to 100, the previously hard-coded value.
      """

      # Native output sizes of known models; any other model is assumed to give 1536.
      _DEFAULT_DIMENSIONS = {"text-embedding-3-large": 3072, "text-embedding-3-small": 1536}

      def __init__(self, model_name="text-embedding-3-large", api_key=None, base_url=None,
                   dimensions=None, batch_size=100):
          # Imported lazily so the openai package is only needed for this backend.
          from openai import OpenAI
          self.model_name = model_name
          self.client = OpenAI(api_key=api_key, base_url=base_url)
          self.dimensions = dimensions
          self.batch_size = batch_size
          if dimensions is None:
              self.dimension = self._DEFAULT_DIMENSIONS.get(model_name, 1536)
          else:
              self.dimension = dimensions
          logger.info("Loaded OpenAI Embedding: {}, dim={}".format(model_name, self.dimension))

      def _request(self, inputs):
          """Embed one batch with a single API call and return vectors in input order."""
          params = {"model": self.model_name, "input": inputs}
          if self.dimensions is not None:
              # Omitted by default so models without shortening support keep working.
              params["dimensions"] = self.dimensions
          response = self.client.embeddings.create(**params)
          items = list(response.data)
          # Each item reports the index of the input it embeds; order by it when
          # every item has one instead of trusting the position in ``data``.
          if all(getattr(item, "index", None) is not None for item in items):
              items.sort(key=lambda item: item.index)
          return [item.embedding for item in items]

      def embed_documents(self, texts):
          """Embed *texts* in chunks of ``self.batch_size``, preserving input order."""
          embeddings = []
          for start in range(0, len(texts), self.batch_size):
              embeddings.extend(self._request(texts[start:start + self.batch_size]))
          return embeddings

      def embed_query(self, text):
          """Embed a single query string."""
          return self._request([text])[0]
  class DashScopeEmbedding(BaseEmbedding):
      """Embeddings from Alibaba DashScope's ``TextEmbedding`` API.

      Args:
          model_name: DashScope embedding model id.
          api_key: If given, stored as the SDK-wide ``dashscope.api_key``. If
              None, any key already configured on the SDK is left untouched.
          batch_size: Maximum texts per request in ``embed_documents``
              (default 25, as before). NOTE(review): newer models may accept
              fewer texts per call; lower this if requests are rejected.

      Raises (from ``embed_documents``/``embed_query``):
          RuntimeError: When a request returns a status code other than 200.
      """

      # Native output sizes of known models; any other model is assumed to give 1536.
      _DEFAULT_DIMENSIONS = {
          "text-embedding-v1": 1536,
          "text-embedding-v2": 1536,
          "text-embedding-v3": 1024,
          "text-embedding-v4": 1024,
      }

      def __init__(self, model_name="text-embedding-v2", api_key=None, batch_size=25):
          # Imported lazily so dashscope is only required for this backend.
          import dashscope
          self.model_name = model_name
          self.batch_size = batch_size
          if api_key is not None:
              # Assigning None here would discard a key configured elsewhere.
              dashscope.api_key = api_key
          self.dimension = self._DEFAULT_DIMENSIONS.get(model_name, 1536)
          logger.info("Loaded DashScope Embedding: {}".format(model_name))

      def embed_documents(self, texts):
          """Embed *texts* in chunks of ``self.batch_size``, preserving input order."""
          from dashscope import TextEmbedding
          embeddings = []
          for start in range(0, len(texts), self.batch_size):
              batch = texts[start:start + self.batch_size]
              response = TextEmbedding.call(model=self.model_name, input=batch)
              if response.status_code != 200:
                  raise RuntimeError("DashScope embedding failed: {}".format(response.message))
              items = list(response.output["embeddings"])
              # Items carry ``text_index`` (position within the request); when all
              # of them do, order by it so vectors line up with ``batch``.
              if all("text_index" in item for item in items):
                  items.sort(key=lambda item: item["text_index"])
              embeddings.extend(item["embedding"] for item in items)
          return embeddings

      def embed_query(self, text):
          """Embed a single query string via ``embed_documents``."""
          return self.embed_documents([text])[0]
  class ZhipuEmbedding(BaseEmbedding):
      """Embeddings from the ZhipuAI embeddings API.

      Args:
          model_name: Zhipu embedding model id.
          api_key: Passed straight to ``ZhipuAI``.
          batch_size: Texts sent per request in ``embed_documents``. Defaults
              to 1, which reproduces the original one-request-per-text
              behaviour. The API accepts a list input, so larger values cut
              the request count. NOTE(review): the provider caps items and
              tokens per request; confirm the limit before raising this.
      """

      # Native output sizes of known models; any other model is assumed to give 2048.
      _DEFAULT_DIMENSIONS = {"embedding-2": 1024, "embedding-3": 2048}

      def __init__(self, model_name="embedding-3", api_key=None, batch_size=1):
          # Imported lazily so zhipuai is only required for this backend.
          from zhipuai import ZhipuAI
          self.model_name = model_name
          self.client = ZhipuAI(api_key=api_key)
          self.batch_size = batch_size
          self.dimension = self._DEFAULT_DIMENSIONS.get(model_name, 2048)
          logger.info("Loaded Zhipu Embedding: {}".format(model_name))

      def embed_documents(self, texts):
          """Embed *texts* in chunks of ``self.batch_size``, preserving input order."""
          embeddings = []
          for start in range(0, len(texts), self.batch_size):
              batch = list(texts[start:start + self.batch_size])
              response = self.client.embeddings.create(model=self.model_name, input=batch)
              embeddings.extend(item.embedding for item in response.data)
          return embeddings

      def embed_query(self, text):
          """Embed a single query string via ``embed_documents``."""
          return self.embed_documents([text])[0]
  def get_embedding(model_type="sentence_transformer", **kwargs):
      """Build the embedding backend registered under *model_type*.

      Args:
          model_type: One of ``"sentence_transformer"``, ``"openai"``,
              ``"dashscope"`` or ``"zhipu"``.
          **kwargs: Forwarded unchanged to the selected class's constructor.

      Returns:
          A new instance of the chosen ``BaseEmbedding`` subclass.

      Raises:
          ValueError: If *model_type* is not one of the names above.
      """
      registry = {
          "sentence_transformer": SentenceTransformerEmbedding,
          "openai": OpenAIEmbedding,
          "dashscope": DashScopeEmbedding,
          "zhipu": ZhipuEmbedding,
      }
      if model_type not in registry:
          raise ValueError("Unknown embedding model type: {}".format(model_type))
      backend_cls = registry[model_type]
      return backend_cls(**kwargs)