role_labeling.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. import psycopg2
  2. import codecs
  3. import xlwt
  4. import xlrd
  5. import os
  6. import re
  7. from xlutils.copy import copy
  8. from BiddingKG.dl.common.Utils import *
  9. import pandas as pd
  10. import math
  11. def getData(t="final_label_role"):
  12. '''
  13. '''
  14. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  15. cursor = conn.cursor()
  16. select_sql = " select A.doc_id,C.entity_id,C.label,case when C.label=0 then '招标人' when C.label=1 then '招标代理' when C.label=2 then '中标人/第一候选' when C.label=3 then '第二' when C.label=4 then '第三' else '无' end as 再标注,case when D.label=0 then '招标人' when D.label=1 then '招标代理' when D.label=2 then '中标人/第一候选' when D.label=3 then '第二' when D.label=4 then '第三' else '无' end as 原标注,B.entity_text,A.tokens[B.begin_index-10:B.begin_index],A.tokens[B.begin_index+1:B.end_index+1],A.tokens[B.end_index+2:B.end_index+12] "
  17. group_sql = " group by A.doc_id,C.entity_id,C.label,D.label,B.entity_text,B.begin_index,B.end_index,A.tokens,A.sentence_index "
  18. sql = select_sql+" from sentences A,entity_mention B,"+t+" C,label_guest_role D where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id and C.entity_id=D.entity_id and C.label!=D.label "+group_sql+"order by A.doc_id,A.sentence_index asc,D.label asc"
  19. cursor.execute(sql)
  20. result = []
  21. rows = cursor.fetchall()
  22. for row in rows:
  23. item = []
  24. for column in row:
  25. item.append(column)
  26. result.append(item)
  27. conn.close()
  28. return result
  29. def labeling(datas):
  30. '''
  31. @summary:标注数据
  32. @param:
  33. datas:待标注数据,包括doc_id,entity_id,标注值,上下文
  34. '''
  35. sum = 0
  36. row_index = 0
  37. begin_doc_id = str(input("开始文章是:"))
  38. begin_index = 0
  39. end_index = len(datas)-1
  40. find_flag = False
  41. while(row_index<len(datas)):
  42. row = datas[row_index]
  43. if begin_doc_id!="" and begin_index==0:
  44. if row[0]==begin_doc_id:
  45. begin_index = row_index
  46. else:
  47. row_index += 1
  48. continue
  49. find_flag = True
  50. print(row[0])
  51. print(row[3],row[4],row[5])
  52. print("before",row[6])
  53. print("entity",row[7])
  54. print("after",row[8])
  55. while(True):
  56. l = str(input("标签为:"))
  57. if l in ["0","1","2","3","4","5","","8","9"]:
  58. break
  59. if l=="0":
  60. row[2] = 0
  61. elif l=="1":
  62. row[2] = 1
  63. elif l=="2":
  64. row[2] = 2
  65. elif l=="3":
  66. row[2] = 3
  67. elif l=="4":
  68. row[2] = 4
  69. elif l=="5":
  70. row[2] = 5
  71. elif l=="":
  72. pass
  73. elif l=="8":
  74. row_index -= 1
  75. sum -= 1
  76. continue
  77. elif l=="9":
  78. end_index = row_index-1
  79. break
  80. sum += 1
  81. row_index += 1
  82. print("sum:",sum)
  83. if find_flag:
  84. with codecs.open("relabel.txt","a",encoding="utf8") as f:
  85. for row in datas[begin_index:end_index+1]:
  86. f.write(str(row[1]))
  87. f.write("\t")
  88. f.write(str(row[2]))
  89. f.write("\n")
  90. f.flush()
  91. f.close()
  92. #设置表格样式
  93. def set_style(name,height,bold=False):
  94. style = xlwt.XFStyle()
  95. font = xlwt.Font()
  96. font.name = name
  97. font.bold = bold
  98. font.color_index = 4
  99. font.height = height
  100. style.font = font
  101. return style
  102. def getDatasToExcel():
  103. '''
  104. @summary:取出待标注数据到excel中
  105. '''
  106. roles = ["0_招标人","1_招标代理","2_中标第一候选","3_第二候选","4_第三候选","5_无"]
  107. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  108. cursor = conn.cursor()
  109. nums = 3
  110. for role in roles:
  111. select_sql = " select C.entity_id,C.label,A.tokens[B.begin_index-10:B.begin_index],A.tokens[B.begin_index+1:B.end_index+1],A.tokens[B.end_index+2:B.end_index+12],case when C.label=0 then '招标人' when C.label=1 then '招标代理' when C.label=2 then '中标人/第一候选' when C.label=3 then '第二候选' when C.label=4 then '第三候选' else '无' end as 再标注 "
  112. sql = select_sql+" from sentences A,entity_mention B,final_label_role C where C.label="+role.split("_")[0]+" and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id and not exists(select 1 from relabel where C.entity_id=relabel.entity_id) order by C.label asc,A.doc_id,A.sentence_index asc limit 20000"
  113. print(sql)
  114. cursor.execute(sql)
  115. rows = cursor.fetchall()
  116. parts = len(rows)//3
  117. for nums_i in range(nums):
  118. file = xlwt.Workbook()
  119. sheet = file.add_sheet("标注"+role,cell_overwrite_ok=True)
  120. row_head = ["entity_id","标注id","实体前","实体","实体后","角色","正确?(1-正确,0-错误)"]
  121. row_index = 0
  122. style = set_style('Times New Roman',220,True)
  123. for i in range(len(row_head)):
  124. sheet.write(row_index,i,row_head[i],style)
  125. row_index += 1
  126. if nums_i<nums-1:
  127. for row in rows[nums_i*parts:(nums_i+1)*parts]:
  128. for i in range(len(row)):
  129. sheet.write(row_index,i,row[i],style)
  130. row_index += 1
  131. else:
  132. for row in rows[nums_i*parts:]:
  133. for i in range(len(row)):
  134. sheet.write(row_index,i,row[i],style)
  135. row_index += 1
  136. file.save("标注"+role.split("_")[1]+str(nums_i)+".xls")
  137. conn.close()
  138. def getDatasFromExcel():
  139. '''
  140. @summary:从已经标注的excel中取出标注数据
  141. '''
  142. home = "./label/role_done/"
  143. col_entity_id = 0
  144. col_label = 1
  145. col_flag = 6
  146. table = "hand_label_role"
  147. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  148. cursor = conn.cursor()
  149. cursor.execute(" select to_regclass('"+table+"') is null ")
  150. notExists = cursor.fetchall()[0][0]
  151. if notExists:
  152. cursor.execute(" create table "+table+" (entity_id text,label int)")
  153. else:
  154. cursor.execute(" delete from "+table)
  155. conn.commit()
  156. for file in os.listdir(home):
  157. if os.path.isfile(home+file):
  158. book = xlrd.open_workbook(home+file)
  159. sheet = book.sheet_by_index(0)
  160. for row_index in range(1,sheet.nrows):
  161. print(row_index,file)
  162. if len(str(sheet.cell_value(row_index,col_flag)))>0 and (int(sheet.cell_value(row_index,col_flag))>0):
  163. sql = " insert into "+table+"(entity_id,label) values('"+str(sheet.cell_value(row_index,col_entity_id))+"',"+str(int(sheet.cell_value(row_index,col_label)))+")"
  164. cursor.execute(sql)
  165. conn.commit()
  166. conn.close()
  167. def selectWrongDatasFromExcel():
  168. '''
  169. @summary:取出标注为错误的数据
  170. '''
  171. home = "./label/role_done/"
  172. data = []
  173. toExcel_file = "../../dl_dev/role/label/role_done/候选中因序号标错到无.xls"
  174. toExcel = xlwt.Workbook()
  175. toExcel_sheet = toExcel.add_sheet("错误标注到无",cell_overwrite_ok=True)
  176. row_head = ["entity_id","标注id","实体前","实体","实体后","角色","正确?(1-正确,0-错误)"]
  177. row_index_toExcel = 0
  178. style = set_style('Times New Roman',220,True)
  179. for i in range(len(row_head)):
  180. toExcel_sheet.write(row_index_toExcel,i,row_head[i],style)
  181. row_index_toExcel += 1
  182. for file in os.listdir(home):
  183. if os.path.isfile(home+file):
  184. if re.search(re.compile("第[一二三]"),file) is not None:
  185. book = xlrd.open_workbook(home+file)
  186. sheet = book.sheet_by_index(0)
  187. changeBook = copy(book)
  188. changeSheet = changeBook.get_sheet(0)
  189. for row_index in range(0,sheet.nrows):
  190. if re.search(re.compile("排名|排序|名次|第[123一二三]|(中标|成交)(人|单位|供应商)|成交情况"),str(sheet.cell_value(row_index,2))) is None:
  191. if re.search(re.compile("序号[::][123]"),str(sheet.cell_value(row_index,4))) is not None:
  192. print(file,sheet.row_values(row_index))
  193. changeSheet.write(row_index,6,0)
  194. row = sheet.row_values(row_index)
  195. row[1] = 5
  196. row[5] = "无"
  197. row[6] = 1
  198. for i in range(len(row)):
  199. toExcel_sheet.write(row_index_toExcel,i,row[i],style)
  200. row_index_toExcel += 1
  201. changeBook.save(home+"".join(file.split(".")[:-1])+"修改序号.xls")
  202. toExcel.save(toExcel_file)
  203. def exportHandLabelData():
  204. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  205. cursor = conn.cursor()
  206. sql = '''
  207. select A.entity_id,A.entity_text,A.begin_index,A.end_index,C.label,B.tokens
  208. from entity_mention A,sentences B ,hand_label_role C
  209. where A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and A.entity_id=C.entity_id
  210. and A.entity_type in ('org','company')
  211. order by C.label
  212. '''
  213. cursor.execute(sql)
  214. rows = cursor.fetchall()
  215. list_entity_id = []
  216. list_before = []
  217. list_after = []
  218. list_text = []
  219. list_label = []
  220. repeat = set()
  221. for row in rows:
  222. entity_id = row[0]
  223. entity_text = row[1]
  224. begin_index = row[2]
  225. end_index = row[3]
  226. label = int(row[4])
  227. tokens = row[5]
  228. beforeafter = spanWindow(tokens, begin_index, end_index, 10)
  229. if ("".join(beforeafter[0]),entity_text,"".join(beforeafter[1])) in repeat:
  230. continue
  231. if str(label)!="5":
  232. continue
  233. repeat.add(("".join(beforeafter[0]),entity_text,"".join(beforeafter[1])))
  234. list_entity_id.append(entity_id)
  235. list_before.append("".join(beforeafter[0]))
  236. list_after.append("".join(beforeafter[1]))
  237. list_text.append(entity_text)
  238. list_label.append(label)
  239. print("len",len(list_entity_id))
  240. parts = 1
  241. parts_num = len(list_entity_id)//parts
  242. for i in range(parts-1):
  243. data = {"entity_id":list_entity_id[i*parts_num:(i+1)*parts_num],"list_before":list_before[i*parts_num:(i+1)*parts_num],"list_after":list_after[i*parts_num:(i+1)*parts_num],"list_text":list_text[i*parts_num:(i+1)*parts_num],"list_label":list_label[i*parts_num:(i+1)*parts_num]}
  244. df = pd.DataFrame(data)
  245. df.to_excel("原先标注数据_role_"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label"])
  246. i = parts - 1
  247. data = {"entity_id":list_entity_id[i*parts_num:],"list_before":list_before[i*parts_num:],"list_after":list_after[i*parts_num:],"list_text":list_text[i*parts_num:],"list_label":list_label[i*parts_num:]}
  248. df = pd.DataFrame(data)
  249. df.to_excel("角色无数据_role_"+str(i)+".xls",columns=["entity_id","list_before","list_text","list_after","list_label"])
  250. def selectWithRule(source,filter,target):
  251. assert target not in filter
  252. assert source!=target
  253. dict_source = pd.read_excel(source)
  254. set_filter = set()
  255. for filt in filter:
  256. set_filter = set_filter | set(pd.read_excel(filt)["entity_id"])
  257. list_entity_id = []
  258. list_before = []
  259. list_text = []
  260. list_after = []
  261. list_label = []
  262. for id,before,text,after,label in zip(dict_source["entity_id"],dict_source["list_before"],dict_source["list_text"],dict_source["list_after"],dict_source["list_label"]):
  263. if id in set_filter:
  264. continue
  265. if re.search("",str(before)) is not None:
  266. list_entity_id.append(id)
  267. list_before.append(before)
  268. list_text.append(text)
  269. list_after.append(after)
  270. list_label.append(label)
  271. data = {"entity_id":list_entity_id,"list_before":list_before,"list_text":list_text,"list_after":list_after,"list_label":list_label}
  272. columns = ["entity_id","list_before","list_text","list_after","list_label"]
  273. df = pd.DataFrame(data)
  274. df.to_excel(target,index=False,columns=columns)
  275. def importreHandLabelData():
  276. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  277. cursor = conn.cursor()
  278. table = "hand_label_role_0409"
  279. files = ["待检测原先标注数据_role_11.xls","待检查原先标注数据_role_00.xls","批量.xls"]
  280. for file in files:
  281. df = pd.read_excel(file)
  282. for entity_id,label,turn in zip(df["entity_id"],df["list_label"],df["turn"]):
  283. new_label = label
  284. #print(entity_id)
  285. if not math.isnan(turn):
  286. new_label = turn
  287. sql = " insert into "+table+"(entity_id,label) values('"+entity_id+"',"+str(new_label)+") "
  288. cursor.execute(sql)
  289. conn.commit()
  290. conn.close()
  291. def dumpData():
  292. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  293. cursor = conn.cursor()
  294. sql = " select B.entity_id,A.tokens,B.entity_text,B.begin_index,B.end_index,C.label from sentences A,entity_mention_copy B,hand_label_role_0409 C where B.entity_type in ('org','company') and A.doc_id=B.doc_id and A.sentence_index=B.sentence_index and B.entity_id=C.entity_id "
  295. cursor.execute(sql)
  296. rows = cursor.fetchall()
  297. save(rows,"id_token_text_begin_end_label.pk")
  298. conn.close()
  299. def relabel():
  300. pkfiles = ["id_token_text_begin_end_label.pk","id_token_text_begin_end_label.pk1"]
  301. list_id = []
  302. list_before = []
  303. list_text = []
  304. list_after = []
  305. list_label = []
  306. for file in pkfiles:
  307. for row in load(file):
  308. id = row[0]
  309. token = row[1]
  310. text = row[2]
  311. begin = int(row[3])
  312. end = int(row[4])
  313. label = int(row[5])
  314. span = spanWindow(token, begin, end, size=10, center_include=True, word_flag=True)
  315. before = span[0]
  316. center = span[1]
  317. after = span[2]
  318. if re.search("中标人.{,3}$",before) is not None:
  319. list_id.append(id)
  320. list_before.append(before)
  321. list_text.append(center)
  322. list_after.append(after)
  323. list_label.append(label)
  324. df = pd.DataFrame({"list_id":list_id,"list_before":list_before,"list_text":list_text,"list_after":list_after,"list_label":list_label})
  325. df.to_excel("rule1.xls",columns=["list_id","list_before","list_text","list_after","list_label"],index=False)
  326. def importAfterrelabel():
  327. conn = psycopg2.connect(dbname="BiddingKM_test_10000",user="postgres",password="postgres",host="192.168.2.101")
  328. cursor = conn.cursor()
  329. conn_1 = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
  330. cursor_1 = conn_1.cursor()
  331. df = pd.read_excel("rule1.xls")
  332. list_id = df["list_id"]
  333. list_label = df["list_label"]
  334. count = 0
  335. for id,label in zip(list_id,list_label):
  336. if re.search("比地",id) is not None:
  337. sql = " update turn_label set new_label='"+str(int(label))+"' where entity_id='"+id+"' "
  338. cursor_1.execute(sql)
  339. else:
  340. sql = " update hand_label_role_0409 set label="+str(int(label))+" where entity_id='"+id+"' "
  341. cursor.execute(sql)
  342. count += 1
  343. print("done",count)
  344. conn.commit()
  345. conn_1.commit()
  346. conn.close()
  347. conn_1.close()
  348. if __name__=="__main__":
  349. pass
  350. #labeling()
  351. #getDatasToExcel()
  352. #getDatasFromExcel()
  353. #selectWrongDatasFromExcel()
  354. #exportHandLabelData()
  355. #selectWithRule("角色无数据_role_0.xls",["批量.xls"],"公告公示.xls")
  356. #importreHandLabelData()
  357. #dumpData()
  358. #relabel()
  359. #importAfterrelabel()