bm25_backend.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # -*- coding: utf-8 -*-
  2. """BM25 backend selection with a small local fallback."""
  3. import math
  4. from collections import Counter
  5. import numpy as np
  6. class SimpleBM25Okapi(object):
  7. """Small BM25Okapi-compatible fallback used when rank_bm25 is unavailable."""
  8. def __init__(self, corpus, k1=1.5, b=0.75, epsilon=0.25):
  9. self.corpus = [list(doc or []) for doc in corpus]
  10. self.k1 = k1
  11. self.b = b
  12. self.epsilon = epsilon
  13. self.corpus_size = len(self.corpus)
  14. self.doc_len = [len(doc) for doc in self.corpus]
  15. self.avgdl = float(sum(self.doc_len)) / self.corpus_size if self.corpus_size else 0.0
  16. self.doc_freqs = [Counter(doc) for doc in self.corpus]
  17. self.idf = self._calc_idf()
  18. def _calc_idf(self):
  19. nd = {}
  20. for freqs in self.doc_freqs:
  21. for word in freqs:
  22. nd[word] = nd.get(word, 0) + 1
  23. idf = {}
  24. negative_idfs = []
  25. for word, freq in nd.items():
  26. value = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
  27. idf[word] = value
  28. if value < 0:
  29. negative_idfs.append(value)
  30. average_idf = sum(idf.values()) / len(idf) if idf else 0.0
  31. eps = self.epsilon * average_idf
  32. for word in idf:
  33. if idf[word] < 0:
  34. idf[word] = eps
  35. return idf
  36. def get_scores(self, query):
  37. scores = np.zeros(self.corpus_size)
  38. if not query or not self.corpus_size or self.avgdl <= 0:
  39. return scores
  40. for token in query:
  41. token_idf = self.idf.get(token)
  42. if token_idf is None:
  43. continue
  44. for i, freqs in enumerate(self.doc_freqs):
  45. freq = freqs.get(token, 0)
  46. if freq == 0:
  47. continue
  48. denominator = freq + self.k1 * (1 - self.b + self.b * self.doc_len[i] / self.avgdl)
  49. scores[i] += token_idf * freq * (self.k1 + 1) / denominator
  50. return scores
  51. def get_bm25_okapi():
  52. try:
  53. from rank_bm25 import BM25Okapi
  54. return BM25Okapi
  55. except ImportError:
  56. return SimpleBM25Okapi