norm.py

#!/usr/bin/env python
# -*- Mode: Python; tab-width: 4; indent-tabs-mode: nil; coding: utf-8; -*-
# vim:set ft=python ts=4 sw=4 sts=4 autoindent:

"""Normalization support."""

from datetime import datetime
from functools import reduce

import normdb
import sdistance
import simstringdb
from document import real_directory
from message import Messager
from normdb import string_norm_form
from projectconfig import ProjectConfiguration

# whether to display alignment scores in search result table
DISPLAY_SEARCH_SCORES = False

# maximum alignment score (tsuruoka_local)
MAX_SCORE = 1000

# maximum alignment score (tsuruoka_local) difference allowed between
# the score for a string s and the best known score before excluding s
# from consideration
MAX_DIFF_TO_BEST_SCORE = 200

# maximum number of search results to return
MAX_SEARCH_RESULT_NUMBER = 1000

NORM_LOOKUP_DEBUG = True

REPORT_LOOKUP_TIMINGS = False


# debugging
def _check_DB_version(database):
    # TODO: not implemented yet for new-style SQL DBs.
    pass


def _report_timings(dbname, start, msg=None):
    delta = datetime.now() - start
    strdelta = str(delta).replace('0:00:0', '')  # take out zero min & hour
    queries = normdb.get_query_count(dbname)
    normdb.reset_query_count(dbname)
    Messager.info("Processed " + str(queries) + " queries in " + strdelta +
                  (msg if msg is not None else ""))


def _get_db_path(database, collection):
    if collection is None:
        # TODO: default to WORK_DIR config?
        return None
    else:
        try:
            conf_dir = real_directory(collection)
            projectconf = ProjectConfiguration(conf_dir)
            norm_conf = projectconf.get_normalization_config()
            for entry in norm_conf:
                dbname, dbpath = entry[0], entry[3]
                if dbname == database:
                    return dbpath
            # not found in config.
            Messager.warning('DB ' + database + ' not defined in config for ' +
                             collection + ', falling back on default.')
            return None
        except Exception:
            # whatever goes wrong, just warn and fall back on the default.
            Messager.warning('Failed to get DB path from config for ' +
                             collection + ', falling back on default.')
            return None


def norm_get_name(database, key, collection=None):
    if NORM_LOOKUP_DEBUG:
        _check_DB_version(database)
    if REPORT_LOOKUP_TIMINGS:
        lookup_start = datetime.now()

    dbpath = _get_db_path(database, collection)
    if dbpath is None:
        # full path not configured, fall back on name as default
        dbpath = database

    try:
        data = normdb.data_by_id(dbpath, key)
    except normdb.dbNotFoundError as e:
        Messager.warning(str(e))
        data = None

    # just grab the first one (sorry, this is a bit opaque)
    if data is not None:
        value = data[0][0][1]
    else:
        value = None

    if REPORT_LOOKUP_TIMINGS:
        _report_timings(database, lookup_start)

    # echo request for sync
    json_dic = {
        'database': database,
        'key': key,
        'value': value
    }
    return json_dic
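
# Illustrative example (hypothetical data, not from the original source):
# a successful lookup echoes the request together with the first stored
# value for the key, e.g.
#   norm_get_name('UniProt', 'Q64131')
#   -> {'database': 'UniProt', 'key': 'Q64131', 'value': 'Runx3'}
# where 'Runx3' stands in for whatever label:value pair the DB stores
# first for that key.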


def norm_get_data(database, key, collection=None):
    if NORM_LOOKUP_DEBUG:
        _check_DB_version(database)
    if REPORT_LOOKUP_TIMINGS:
        lookup_start = datetime.now()

    dbpath = _get_db_path(database, collection)
    if dbpath is None:
        # full path not configured, fall back on name as default
        dbpath = database

    try:
        data = normdb.data_by_id(dbpath, key)
    except normdb.dbNotFoundError as e:
        Messager.warning(str(e))
        data = None

    if data is None:
        Messager.warning("Failed to get data for " + database + ":" + key)

    if REPORT_LOOKUP_TIMINGS:
        _report_timings(database, lookup_start)

    # echo request for sync
    json_dic = {
        'database': database,
        'key': key,
        'value': data
    }
    return json_dic


# TODO: deprecated, confirm unnecessary and remove.
# def norm_get_ids(database, name, collection=None):
#     if NORM_LOOKUP_DEBUG:
#         _check_DB_version(database)
#     if REPORT_LOOKUP_TIMINGS:
#         lookup_start = datetime.now()
#
#     dbpath = _get_db_path(database, collection)
#     if dbpath is None:
#         # full path not configured, fall back on name as default
#         dbpath = database
#
#     keys = normdb.ids_by_name(dbpath, name)
#
#     if REPORT_LOOKUP_TIMINGS:
#         _report_timings(database, lookup_start)
#
#     # echo request for sync
#     json_dic = {
#         'database': database,
#         'value': name,
#         'keys': keys,
#     }
#     return json_dic


def _format_datas(datas, scores=None, matched=None):
    # helper for norm_search(): formats data from the DB into a table
    # for the client, sorted by scores if given.
    if scores is None:
        scores = {}
    if matched is None:
        matched = {}

    # chop off all but the first two groups of label:value pairs for
    # each key; latter ones are assumed to be additional information
    # not intended for display of search results.
    # TODO: avoid the unnecessary queries for this information.
    cropped = {}
    for key in datas:
        cropped[key] = datas[key][:2]
    datas = cropped

    # organize into a table format with separate header and data
    # (this matches the collection browser data format)
    unique_labels = []
    seen_label = {}
    for key in datas:
        # check for dups within each entry
        seen_label_for_key = {}
        for i, group in enumerate(datas[key]):
            for label, value in group:
                if label not in seen_label:
                    # store with group index to sort all labels by
                    # group idx first
                    unique_labels.append((i, label))
                    seen_label[label] = True
                if label in seen_label_for_key:
                    # too noisy, and not really harmful now that matching
                    # values are preferred for repeated labels.
                    # Messager.warning("Repeated label (%s) in normalization data not supported" % label)
                    pass
                seen_label_for_key[label] = True

    # sort unique labels by group index (Python's sort is stable,
    # guaranteed since 2.3), and flatten
    unique_labels.sort(key=lambda a: a[0])
    unique_labels = [a[1] for a in unique_labels]

    # ID is the first field, and the datatype is "string" for all labels
    header = [(label, "string") for label in ["ID"] + unique_labels]
    if DISPLAY_SEARCH_SCORES:
        header += [("score", "int")]

    # construct items, sorted by score first, ID second (latter for stability)
    sorted_keys = sorted(list(datas.keys()),
                         key=lambda a: (scores.get(a, 0), a),
                         reverse=True)

    items = []
    for key in sorted_keys:
        # make dict for lookup. In case of duplicates (e.g. multiple
        # "synonym" entries), prefer ones that were matched.
        # TODO: prefer more exact matches when multiple found.
        data_dict = {}
        for group in datas[key]:
            for label, value in group:
                if label not in data_dict or (value in matched and
                                              data_dict[label] not in matched):
                    data_dict[label] = value
        # construct item
        item = [str(key)]
        for label in unique_labels:
            if label in data_dict:
                item.append(data_dict[label])
            else:
                item.append('')

        if DISPLAY_SEARCH_SCORES:
            item += [str(scores.get(key))]

        items.append(item)

    return header, items
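
# Illustrative shapes (hypothetical data, not from the original source):
# given
#   datas = {'Q64131': [[('name', 'Runx3')], [('organism', 'mouse')]]}
# _format_datas returns
#   header = [('ID', 'string'), ('name', 'string'), ('organism', 'string')]
#   items = [['Q64131', 'Runx3', 'mouse']]
# (a "score" column is appended when DISPLAY_SEARCH_SCORES is set).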


def _norm_filter_score(score, best_score=MAX_SCORE):
    return score < best_score - MAX_DIFF_TO_BEST_SCORE
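
# Worked example: with MAX_SCORE = 1000 and MAX_DIFF_TO_BEST_SCORE = 200,
# once some candidate has reached a score of 1000, a candidate scoring
# 799 is filtered out (799 < 1000 - 200) while an 800-scoring candidate
# is kept.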


# TODO: get rid of arbitrary max_cost default constant
def _norm_score(substring, name, max_cost=500):
    # returns an integer score representing the similarity of the given
    # substring to the given name (larger is better).
    cache = _norm_score.__cache
    if (substring, name) not in cache:
        cost = sdistance.tsuruoka_local(substring, name, max_cost=max_cost)
        # debugging
        # Messager.info('%s --- %s: %d (max %d)' % (substring, name, cost, max_cost))
        score = MAX_SCORE - cost
        cache[(substring, name)] = score
    # TODO: should we avoid exceeding max_cost? Cached values might.
    return cache[(substring, name)]


_norm_score.__cache = {}
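
# Note: the memoization cache above is keyed by (substring, name) only,
# ignoring max_cost, so a value cached under a generous cost bound may be
# returned for a later call with a stricter one (hence the TODO above).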


def _norm_search_name_attr(database, name, attr,
                           matched, score_by_id, score_by_str,
                           best_score=0, exactmatch=False,
                           threshold=simstringdb.DEFAULT_THRESHOLD):
    # Helper for norm_search(); searches for matches where the given name
    # appears either in full or as an approximate substring of a full
    # name (if exactmatch is False) in the given DB. If attr is not None,
    # requires its value to appear as an attribute of the entry with
    # the matched name. Updates matched, score_by_id, and score_by_str;
    # returns best_score.

    # If there are no strict substring matches for a given attribute
    # in the simstring DB, we can be sure that no query can succeed,
    # and can fail early.
    # TODO: this would be more effective (as would some other things)
    # if the attributes were in a separate simstring DB from the
    # names.
    if attr is not None:
        normattr = string_norm_form(attr)
        if not simstringdb.ssdb_supstring_exists(normattr, database, 1.0):
            # debugging
            # Messager.info('Early norm search fail on "%s"' % attr)
            return best_score

    if exactmatch:
        # the only candidate string is the given name
        strs = [name]
        ss_norm_score = {string_norm_form(name): 1.0}
    else:
        # expand to substrings using simstring
        # (simstring requires UTF-8)
        normname = string_norm_form(name)
        str_scores = simstringdb.ssdb_supstring_lookup(normname, database,
                                                       threshold, True)
        strs = [s[0] for s in str_scores]
        ss_norm_score = dict(str_scores)

        # TODO: recreate this older filter; watch out for which name to use!
        # # filter to strings not already considered
        # strs = [s for s in strs if (normname, s) not in score_by_str]

    # look up IDs
    if attr is None:
        id_names = normdb.ids_by_names(database, strs, False, True)
    else:
        id_names = normdb.ids_by_names_attr(database, strs, attr, False, True)

    # sort by simstring (n-gram overlap) score to prioritize likely
    # good hits.
    # TODO: this doesn't seem to be having a very significant effect.
    # consider removing as unnecessary complication (ss_norm_score also).
    id_name_scores = [(i, n, ss_norm_score[string_norm_form(n)])
                      for i, n in id_names]
    id_name_scores.sort(key=lambda a: a[2], reverse=True)
    id_names = [(i, n) for i, n, s in id_name_scores]

    # update matches and scores
    for i, n in id_names:
        if n not in matched:
            matched[n] = set()
        matched[n].add(i)

        # bound the alignment cost: anything scoring more than
        # MAX_DIFF_TO_BEST_SCORE below the best is filtered out later anyway
        max_cost = MAX_SCORE - best_score + MAX_DIFF_TO_BEST_SCORE + 1
        if (name, n) not in score_by_str:
            # TODO: decide whether to use normalized or unnormalized strings
            # for scoring here.
            # score_by_str[(name, n)] = _norm_score(name, n, max_cost)
            score_by_str[(name, n)] = _norm_score(
                string_norm_form(name), string_norm_form(n), max_cost)
        score = score_by_str[(name, n)]
        best_score = max(score, best_score)

        score_by_id[i] = max(score_by_id.get(i, -1),
                             score_by_str[(name, n)])

        # stop if max count reached
        if len(score_by_id) > MAX_SEARCH_RESULT_NUMBER:
            Messager.info(
                'Note: more than %d search results, only retrieving top matches' %
                MAX_SEARCH_RESULT_NUMBER)
            break

    return best_score
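
# Overview sketch (hypothetical query): for name='Runx3' and attr='mouse',
# the helper first requires the normalized form of 'mouse' to occur as a
# substring of some entry in the simstring DB (early exit otherwise), then
# expands the normalized query to candidate full names via simstring, maps
# the candidates to IDs with normdb, and ranks the IDs by tsuruoka_local
# alignment score against the query.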


def _norm_search_impl(database, name, collection=None, exactmatch=False):
    if NORM_LOOKUP_DEBUG:
        _check_DB_version(database)
    if REPORT_LOOKUP_TIMINGS:
        lookup_start = datetime.now()

    dbpath = _get_db_path(database, collection)
    if dbpath is None:
        # full path not configured, fall back on name as default
        dbpath = database

    # maintain map from searched names to matching IDs and scores for
    # ranking
    matched = {}
    score_by_id = {}
    score_by_str = {}

    # look up hits where name appears in full
    best_score = _norm_search_name_attr(dbpath, name, None,
                                        matched, score_by_id, score_by_str,
                                        0, exactmatch)

    # if there are no hits and we only have a simple candidate string,
    # look up with a low threshold
    if best_score == 0 and len(name.split()) == 1:
        best_score = _norm_search_name_attr(dbpath, name, None,
                                            matched, score_by_id,
                                            score_by_str,
                                            0, exactmatch, 0.5)

    # if there are no good hits, also consider only part of the input
    # as the name and the rest as an attribute.
    # TODO: reconsider arbitrary cutoff
    if best_score < 900 and not exactmatch:
        parts = name.split()

        # prioritize having the attribute after the name
        for i in range(len(parts) - 1, 0, -1):
            # TODO: this early termination is sub-optimal: it's not
            # possible to know in advance which way of splitting the
            # query into parts yields the best results. Reconsider.
            if len(score_by_id) > MAX_SEARCH_RESULT_NUMBER:
                break

            start = ' '.join(parts[:i])
            end = ' '.join(parts[i:])

            # query both ways (start is name, end is attr and vice versa)
            best_score = _norm_search_name_attr(dbpath, start, end,
                                                matched, score_by_id,
                                                score_by_str,
                                                best_score, exactmatch)
            best_score = _norm_search_name_attr(dbpath, end, start,
                                                matched, score_by_id,
                                                score_by_str,
                                                best_score, exactmatch)

    # flatten to a single set of IDs
    ids = reduce(set.union, list(matched.values()), set())

    # filter out IDs that now (after all queries complete) fail
    # TODO: are we sure that this is a good idea?
    ids = set([i for i in ids
               if not _norm_filter_score(score_by_id[i], best_score)])

    # TODO: avoid unnecessary queries: datas_by_ids queries for names,
    # attributes and infos, but _format_datas only uses the first two.
    datas = normdb.datas_by_ids(dbpath, ids)

    header, items = _format_datas(datas, score_by_id, matched)

    if REPORT_LOOKUP_TIMINGS:
        _report_timings(database, lookup_start,
                        ", retrieved " + str(len(items)) + " items")

    # echo request for sync
    json_dic = {
        'database': database,
        'query': name,
        'header': header,
        'items': items,
    }
    return json_dic
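
# Worked example (hypothetical query): for name='Runx3 mouse', the full
# string is tried first; if the best score stays under the 900 cutoff,
# the query is split into ('Runx3', 'mouse') and each part is tried both
# as the name and as the attribute, keeping the best-scoring hits.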


def norm_search(database, name, collection=None, exactmatch=False):
    try:
        return _norm_search_impl(database, name, collection, exactmatch)
    except simstringdb.ssdbNotFoundError as e:
        Messager.warning(str(e))
        return {
            'database': database,
            'query': name,
            'header': [],
            'items': []
        }
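
# Usage sketch (actual values depend on the configured normalization DBs):
#   result = norm_search('UniProt', 'Runx3 mouse')
#   result['header']  # e.g. [('ID', 'string'), ('name', 'string'), ...]
#   result['items']   # e.g. [['Q64131', 'Runx3', ...], ...]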


def _test():
    # test
    test_cases = {
        'UniProt': {
            'Runx3': 'Q64131',
            'Runx3 mouse': 'Q64131',
            'Runx1': 'Q03347',
            'Runx1 mouse': 'Q03347',
            'Eomes': 'O54839',
            'Eomes mouse': 'O54839',
            'granzyme B': 'P04187',
            'granzyme B mouse': 'P04187',
            'INF-gamma': 'P01580',
            'INF-gamma mouse': 'P01580',
            'IL-2': 'P04351',
            'IL-2 mouse': 'P04351',
            'T-bet': 'Q9JKD8',
            'T-bet mouse': 'Q9JKD8',
            'GATA-1': 'P15976',
            'GATA-1 human': 'P15976',
            'Interleukin-10': 'P22301',
            'Interleukin-10 human': 'P22301',
            'Interleukin-12': 'P29459',
            'Interleukin-12 human': 'P29459',
            'interferon-gamma': 'P01579',
            'interferon-gamma human': 'P01579',
            'interferon gamma human': 'P01579',
            'Fas ligand': 'P48023',
            'Fas ligand human': 'P48023',
            'IkappaB-alpha': 'P25963',
            'IkappaB-alpha human': 'P25963',
            'transforming growth factor (TGF)-beta1': 'P01137',
            'transforming growth factor (TGF)-beta1 human': 'P01137',
            'transforming growth factor beta1 human': 'P01137',
            'tumor necrosis factor alpha': 'P01375',
            'tumor necrosis factor alpha human': 'P01375',
            'Epstein-Barr virus latent membrane protein LMP1': 'Q1HVB3',
            'TATA box binding protein': 'P20226',
            'TATA box binding protein human': 'P20226',
            'HIV protease': '??????',  # TODO
            'human immunodeficiency virus type 1 (HIV) protease': '??????',  # TODO
        }
    }

    overall_start = datetime.now()
    query_count, hit_count = 0, 0
    misses = []
    for DB in test_cases:
        for query in test_cases[DB]:
            target = test_cases[DB][query]
            start = datetime.now()
            results = norm_search(DB, query)
            delta = datetime.now() - start
            found = False
            found_rank = -1
            for rank, item in enumerate(results['items']):
                id_ = item[0]
                if id_ == target:
                    found = True
                    found_rank = rank + 1
                    break
            strdelta = str(delta).replace('0:00:0', '').replace('0:00:', '')
            print("%s: '%s' <- '%s' rank %d/%d (%s sec)" %
                  (' ok' if found else 'MISS',
                   target, query,
                   found_rank,
                   len(results['items']),
                   strdelta))
            query_count += 1
            if found:
                hit_count += 1
            else:
                misses.append((query, target))

    if len(misses) != 0:
        print()
        print("MISSED:")
        for query, target in misses:
            print("%s '%s'" % (target, query))

    delta = datetime.now() - overall_start
    strdelta = str(delta).replace('0:00:0', '').replace('0:00:', '')
    print()
    print("Found %d / %d in %s" % (hit_count, query_count, strdelta))


def _profile_test():
    # runs _test() with profiling, storing results in "norm.profile".
    # To see a profile, run e.g.
    # python -c 'import pstats; pstats.Stats("norm.profile").strip_dirs().sort_stats("time").print_stats()' | less
    import cProfile
    cProfile.run('_test()', 'norm.profile')


if __name__ == '__main__':
    _test()  # normal
    # _profile_test()  # profiled