normdb.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. #!/usr/bin/env python
  2. # -*- Mode: Python; tab-width: 4; indent-tabs-mode: nil; coding: utf-8; -*-
  3. # vim:set ft=python ts=4 sw=4 sts=4 autoindent:
  4. """Functionality for normalization SQL database access."""
  5. import sqlite3 as sqlite
  6. import sys
  7. from os.path import join as path_join
  8. from os.path import sep as path_sep
  9. from os.path import exists
  10. try:
  11. from config import BASE_DIR, WORK_DIR
  12. except ImportError:
  13. # for CLI use; assume we're in brat server/src/ and config is in root
  14. from sys import path as sys_path
  15. from os.path import dirname
  16. sys_path.append(path_join(dirname(__file__), '../..'))
  17. from config import BASE_DIR, WORK_DIR
  18. # Filename extension used for DB file.
  19. DB_FILENAME_EXTENSION = 'db'
  20. # Names of tables with information on each entry
  21. TYPE_TABLES = ["names", "attributes", "infos"]
  22. # Names of tables that must have some value for an entry
  23. NON_EMPTY_TABLES = set(["names"])
  24. # Maximum number of variables in one SQL query (TODO: get from lib!)
  25. MAX_SQL_VARIABLE_COUNT = 999
  26. __query_count = {}
  27. class dbNotFoundError(Exception):
  28. def __init__(self, fn):
  29. self.fn = fn
  30. def __str__(self):
  31. return 'Database file "%s" not found' % self.fn
  32. # Normalizes a given string for search. Used to implement
  33. # case-insensitivity and similar in search.
  34. # NOTE: this is a different sense of "normalization" than that
  35. # implemented by a normalization DB as a whole: this just applies to
  36. # single strings.
  37. # NOTE2: it is critically important that this function is performed
  38. # identically during DB initialization and actual lookup.
  39. # TODO: enforce a single implementation.
  40. def string_norm_form(s):
  41. return s.lower().strip().replace('-', ' ')
  42. def __db_path(db):
  43. """Given a DB name/path, returns the path for the file that is expected to
  44. contain the DB."""
  45. # Assume we have a path relative to the brat root if the value
  46. # contains a separator, name only otherwise.
  47. # TODO: better treatment of name / path ambiguity, this doesn't
  48. # allow e.g. DBs to be located in brat root
  49. if path_sep in db:
  50. base = BASE_DIR
  51. else:
  52. base = WORK_DIR
  53. return path_join(base, db + '.' + DB_FILENAME_EXTENSION)
  54. def reset_query_count(dbname):
  55. global __query_count
  56. __query_count[dbname] = 0
  57. def get_query_count(dbname):
  58. global __query_count
  59. return __query_count.get(dbname, 0)
  60. def __increment_query_count(dbname):
  61. global __query_count
  62. __query_count[dbname] = __query_count.get(dbname, 0) + 1
  63. def _get_connection_cursor(dbname):
  64. # helper for DB access functions
  65. dbfn = __db_path(dbname)
  66. # open DB
  67. if not exists(dbfn):
  68. raise dbNotFoundError(dbfn)
  69. connection = sqlite.connect(dbfn)
  70. cursor = connection.cursor()
  71. return connection, cursor
  72. def _execute_fetchall(cursor, command, args, dbname):
  73. # helper for DB access functions
  74. cursor.execute(command, args)
  75. __increment_query_count(dbname)
  76. return cursor.fetchall()
  77. def data_by_id(dbname, id_):
  78. """Given a DB name and an entity id, returns all the information contained
  79. in the DB for the id."""
  80. connection, cursor = _get_connection_cursor(dbname)
  81. # select separately from names, attributes and infos
  82. responses = {}
  83. for table in TYPE_TABLES:
  84. command = '''
  85. SELECT L.text, N.value
  86. FROM entities E
  87. JOIN %s N
  88. ON E.id = N.entity_id
  89. JOIN labels L
  90. ON L.id = N.label_id
  91. WHERE E.uid=?''' % table
  92. responses[table] = _execute_fetchall(cursor, command, (id_, ), dbname)
  93. # short-circuit on missing or incomplete entry
  94. if table in NON_EMPTY_TABLES and len(responses[table]) == 0:
  95. break
  96. cursor.close()
  97. # empty or incomplete?
  98. for t in NON_EMPTY_TABLES:
  99. if len(responses[t]) == 0:
  100. return None
  101. # has content, format and return
  102. combined = []
  103. for t in TYPE_TABLES:
  104. combined.append(responses[t])
  105. return combined
  106. def ids_by_name(dbname, name, exactmatch=False, return_match=False):
  107. return ids_by_names(dbname, [name], exactmatch, return_match)
  108. def ids_by_names(dbname, names, exactmatch=False, return_match=False):
  109. if len(names) < MAX_SQL_VARIABLE_COUNT:
  110. return _ids_by_names(dbname, names, exactmatch, return_match)
  111. else:
  112. # break up into several queries
  113. result = []
  114. i = 0
  115. while i < len(names):
  116. n = names[i:i + MAX_SQL_VARIABLE_COUNT]
  117. r = _ids_by_names(dbname, n, exactmatch, return_match)
  118. result.extend(r)
  119. i += MAX_SQL_VARIABLE_COUNT
  120. return result
  121. def _ids_by_names(dbname, names, exactmatch=False, return_match=False):
  122. """Given a DB name and a list of entity names, returns the ids of all
  123. entities having one of the given names.
  124. Uses exact string lookup if exactmatch is True, otherwise performs
  125. normalized string lookup (case-insensitive etc.). If return_match is
  126. True, returns pairs of (id, matched name), otherwise returns only
  127. ids.
  128. """
  129. connection, cursor = _get_connection_cursor(dbname)
  130. if not return_match:
  131. command = 'SELECT E.uid'
  132. else:
  133. command = 'SELECT E.uid, N.value'
  134. command += '''
  135. FROM entities E
  136. JOIN names N
  137. ON E.id = N.entity_id
  138. '''
  139. if exactmatch:
  140. command += 'WHERE N.value IN (%s)' % ','.join(['?' for n in names])
  141. else:
  142. command += 'WHERE N.normvalue IN (%s)' % ','.join(['?' for n in names])
  143. names = [string_norm_form(n) for n in names]
  144. responses = _execute_fetchall(cursor, command, names, dbname)
  145. cursor.close()
  146. if not return_match:
  147. return [r[0] for r in responses]
  148. else:
  149. return [(r[0], r[1]) for r in responses]
  150. def ids_by_name_attr(dbname, name, attr, exactmatch=False, return_match=False):
  151. return ids_by_names_attr(dbname, [name], attr, exactmatch, return_match)
  152. def ids_by_names_attr(dbname, names, attr, exactmatch=False,
  153. return_match=False):
  154. if len(names) < MAX_SQL_VARIABLE_COUNT - 1:
  155. return _ids_by_names_attr(
  156. dbname, names, attr, exactmatch, return_match)
  157. else:
  158. # break up
  159. result = []
  160. i = 0
  161. while i < len(names):
  162. # -1 for attr
  163. n = names[i:i + MAX_SQL_VARIABLE_COUNT - 1]
  164. r = _ids_by_names_attr(dbname, n, attr, exactmatch, return_match)
  165. result.extend(r)
  166. i += MAX_SQL_VARIABLE_COUNT - 1
  167. return result
  168. def _ids_by_names_attr(dbname, names, attr, exactmatch=False,
  169. return_match=False):
  170. """Given a DB name, a list of entity names, and an attribute text, returns
  171. the ids of all entities having one of the given names and an attribute
  172. matching the given attribute.
  173. Uses exact string lookup if exactmatch is True, otherwise performs
  174. normalized string lookup (case-insensitive etc.). If return_match is
  175. True, returns pairs of (id, matched name), otherwise returns only
  176. names.
  177. """
  178. connection, cursor = _get_connection_cursor(dbname)
  179. if not return_match:
  180. command = 'SELECT E.uid'
  181. else:
  182. command = 'SELECT E.uid, N.value'
  183. command += '''
  184. FROM entities E
  185. JOIN names N
  186. ON E.id = N.entity_id
  187. JOIN attributes A
  188. ON E.id = A.entity_id
  189. '''
  190. if exactmatch:
  191. command += 'WHERE N.value IN (%s) AND A.value=?' % ','.join([
  192. '?' for n in names])
  193. else:
  194. # NOTE: using 'LIKE', not '=' here
  195. command += 'WHERE N.normvalue IN (%s) AND A.normvalue LIKE ?' % ','.join([
  196. '?' for n in names])
  197. attr = '%' + string_norm_form(attr) + '%'
  198. names = [string_norm_form(n) for n in names]
  199. responses = _execute_fetchall(cursor, command, names + [attr], dbname)
  200. cursor.close()
  201. if not return_match:
  202. return [r[0] for r in responses]
  203. else:
  204. return [(r[0], r[1]) for r in responses]
  205. def datas_by_ids(dbname, ids):
  206. if len(ids) < MAX_SQL_VARIABLE_COUNT:
  207. return _datas_by_ids(dbname, ids)
  208. else:
  209. # break up
  210. datas = {}
  211. i = 0
  212. ids = list(ids)
  213. while i < len(ids):
  214. ids_ = ids[i:i + MAX_SQL_VARIABLE_COUNT]
  215. r = _datas_by_ids(dbname, ids_)
  216. for k in r:
  217. datas[k] = r[k]
  218. i += MAX_SQL_VARIABLE_COUNT
  219. return datas
  220. def _datas_by_ids(dbname, ids):
  221. """Given a DB name and a list of entity ids, returns all the information
  222. contained in the DB for the ids."""
  223. connection, cursor = _get_connection_cursor(dbname)
  224. # select separately from names, attributes and infos
  225. responses = {}
  226. for table in TYPE_TABLES:
  227. command = '''
  228. SELECT E.uid, L.text, N.value
  229. FROM entities E
  230. JOIN %s N
  231. ON E.id = N.entity_id
  232. JOIN labels L
  233. ON L.id = N.label_id
  234. WHERE E.uid IN (%s)''' % (table, ','.join(['?' for i in ids]))
  235. response = _execute_fetchall(cursor, command, list(ids), dbname)
  236. # group by ID first
  237. for id_, label, value in response:
  238. if id_ not in responses:
  239. responses[id_] = {}
  240. if table not in responses[id_]:
  241. responses[id_][table] = []
  242. responses[id_][table].append([label, value])
  243. # short-circuit on missing or incomplete entry
  244. if (table in NON_EMPTY_TABLES and
  245. len([i for i in responses if responses[i][table] == 0]) != 0):
  246. return None
  247. cursor.close()
  248. # empty or incomplete?
  249. for id_ in responses:
  250. for t in NON_EMPTY_TABLES:
  251. if len(responses[id_][t]) == 0:
  252. return None
  253. # has expected content, format and return
  254. datas = {}
  255. for id_ in responses:
  256. datas[id_] = []
  257. for t in TYPE_TABLES:
  258. datas[id_].append(responses[id_].get(t, []))
  259. return datas
  260. def datas_by_name(dbname, name, exactmatch=False):
  261. # TODO: optimize
  262. datas = {}
  263. for id_ in ids_by_name(dbname, name, exactmatch):
  264. datas[id_] = data_by_id(dbname, id_)
  265. return datas
  266. if __name__ == "__main__":
  267. # test
  268. if len(sys.argv) > 1:
  269. dbname = sys.argv[1]
  270. else:
  271. dbname = "FMA"
  272. if len(sys.argv) > 2:
  273. id_ = sys.argv[2]
  274. else:
  275. id_ = "10883"
  276. print(data_by_id(dbname, id_))
  277. print(ids_by_name(dbname, 'Pleural branch of left sixth posterior intercostal artery'))
  278. print(datas_by_name(dbname, 'Pleural branch of left sixth posterior intercostal artery'))