definition.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import pickle
  2. import requests
  3. import json
  4. from ipywidgets import widgets
  5. from IPython.display import display,clear_output
  6. import os
  7. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  8. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  9. def getHbox(entity):
  10. check = False if entity[5]=="1" else True
  11. return widgets.HBox([widgets.ToggleButton(
  12. value=check,
  13. description='表述错误',
  14. disabled=False,
  15. layout=widgets.Layout(width="100px",height="100px"),
  16. icon='check'
  17. ),
  18. widgets.Label(value="表述:",layout=widgets.Layout(width="60px",height="100px")),
  19. widgets.Textarea(value=getBS(entity),layout=widgets.Layout(width="170px",height="100px")),
  20. widgets.Label(value="前后文:",layout=widgets.Layout(width="100px",height="100px")),
  21. widgets.Textarea(value="".join(entity[0]),layout=widgets.Layout(width="170px",height="100px")),
  22. widgets.Textarea(value="".join(entity[1]),layout=widgets.Layout(width="170px",height="100px")),
  23. widgets.Textarea(value="".join(entity[2]),layout=widgets.Layout(width="170px",height="100px"))])
  24. def save(object_to_save, path):
  25. '''
  26. 保存对象
  27. @Arugs:
  28. object_to_save: 需要保存的对象
  29. @Return:
  30. 保存的路径
  31. '''
  32. with open(path, 'wb') as f:
  33. pickle.dump(object_to_save, f)
  34. def load(path):
  35. '''
  36. 读取对象
  37. @Arugs:
  38. path: 读取的路径
  39. @Return:
  40. 读取的对象
  41. '''
  42. with open(path, 'rb') as f:
  43. object1 = pickle.load(f)
  44. return object1
  45. guardian_base = 'http://127.0.0.1:15010'
  46. myheaders = {'Content-Type': 'application/json'}
  47. source_data_file = "data.pk"
  48. import psycopg2
  49. from DBUtils.PooledDB import PooledDB
  50. pool = None
  51. def getConnection():
  52. global pool
  53. if pool is None:
  54. pool = PooledDB(psycopg2, 5,5,dbname="article_label", host="192.168.2.101",user="postgres",password="postgres",port="5432")
  55. return pool.connection()
  56. def make(index_,source_data):
  57. user = {
  58. "id": source_data[index_][0],
  59. "content":source_data[index_][1]
  60. }
  61. _resp = requests.post(guardian_base + '/article_extract', json=user, headers=myheaders, verify=True)
  62. return json.loads(_resp.content.decode("utf-8"))["success"] is True
  63. BS_dic = {"org":{"0":"角色-招标人","1":"角色-代理人","2":"角色-中标/第一候选人","3":"角色-第二候选人","4":"角色-第三候选人","5":"角色-无"},
  64. "company":{"0":"角色-招标人","1":"角色-代理人","2":"角色-中标/第一候选人","3":"角色-第二候选人","4":"角色-第三候选人","5":"角色-无"},
  65. "money":{"0":"金额-招标金额","1":"金额-中投标金额","2":"金额-其他金额"},
  66. "person":{"0":"联系人-非目标联系人","1":"联系人-招标联系人","2":"联系人-代理联系人","3":"联系人-联系人"}}
  67. def getBS(entity):
  68. return BS_dic[entity[3]][entity[4]]
  69. def getEntitys(index_,source_data):
  70. id = source_data[index_][0]
  71. conn = getConnection()
  72. cursor = conn.cursor()
  73. sql = " select B.tokens,A.entity_text,A.entity_type,A.label,A.handlabel,A.entity_id,A.begin_index,A.end_index,A.values from entity_mention A,sentences B where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and A.label !='None' "+\
  74. " and B.doc_id='"+id+"' order by A.label,A.entity_type "
  75. cursor.execute(sql)
  76. rows = cursor.fetchall()
  77. data = []
  78. for row in rows:
  79. tokens = row[0]
  80. entity_text = row[1]
  81. entity_type = row[2]
  82. label = row[3]
  83. handlabel = row[4]
  84. entity_id = row[5]
  85. begin_index = row[6]
  86. end_index = row[7]
  87. values = row[8]
  88. prob = values[1:-1].split(",")[int(label)]
  89. if float(prob)<0.5:
  90. continue
  91. span = spanWindow(tokens,begin_index,end_index,10)
  92. data.append([span[0],span[1],span[2],entity_type,label,handlabel,entity_id])
  93. conn.close()
  94. return data
  95. def spanWindow(tokens,begin_index,end_index,size):
  96. '''
  97. @summary:取得某个实体的上下文词汇
  98. @param:
  99. tokens:句子分词list
  100. begin_index:实体的开始index
  101. end_index:实体的结束index
  102. size:左右两边各取多少个词
  103. @return: list,实体的上下文词汇
  104. '''
  105. length_tokens = len(tokens)
  106. if begin_index>size:
  107. begin = begin_index-size
  108. else:
  109. begin = 0
  110. if end_index+size<length_tokens:
  111. end = end_index+size+1
  112. else:
  113. end = length_tokens
  114. result = []
  115. result.append(tokens[begin:begin_index])
  116. result.append(tokens[begin_index:end_index+1])
  117. result.append(tokens[end_index+1:end])
  118. return result
  119. def getCodeName(index_,source_data):
  120. id = source_data[index_][0]
  121. conn = getConnection()
  122. cursor = conn.cursor()
  123. sql = " select code,name from articles_processed where id='"+id+"' "
  124. cursor.execute(sql)
  125. rows = cursor.fetchall()
  126. conn.close()
  127. if len(rows)>0:
  128. return rows[0][0],rows[0][1]
  129. else:
  130. return "",""
  131. def saveData(datas,out_code,begin_index,source_data,out_name,out_vbox):
  132. if out_code.value=="" and out_name.value=="":
  133. print("请标注编号名称")
  134. return 1
  135. conn = getConnection()
  136. cursor = conn.cursor()
  137. sql = " update articles_processed set code='"+out_code.value+"',name='"+out_name.value+"' where id='"+source_data[begin_index][0]+"'"
  138. cursor.execute(sql)
  139. for i in range(len(datas)):
  140. handlabel = "0" if out_vbox.children[i].children[0].value else "1"
  141. if handlabel == "0":
  142. sql = " update entity_mention set handlabel='"+handlabel+"' where entity_id='"+datas[i][6]+"' and entity_type='"+datas[i][3]+"'"
  143. cursor.execute(sql)
  144. conn.commit()
  145. conn.close()
  146. return 0