# generateData.py
'''
Created on 2019-03-25
@author: User
'''
import glob
import re
import copy
# os, json and numpy are used further down; they may also come in via the * imports
# below, but importing them explicitly keeps the module self-contained.
import os
import json
import numpy as np
from bs4 import BeautifulSoup
import codecs
import pandas as pd
from BiddingKG.dl.interface.predictor import *
from BiddingKG.dl.form.feature import *
import psycopg2
from BiddingKG.dl.common.Utils import *

# formPredictor = FormPredictor()

def tableToText(soup,data,file,data_set_is,data_set_no):
    '''
    @param:
        soup: the BeautifulSoup of the page html
    @return: the page soup with every table flattened into text
    '''
    def getTrs(tbody):
        # collect all tr elements, including those nested in an inner tbody
        trs = []
        objs = tbody.find_all(recursive=False)
        for obj in objs:
            if obj.name=="tr":
                trs.append(obj)
            if obj.name=="tbody":
                for tr in obj.find_all("tr",recursive=False):
                    trs.append(tr)
        return trs

    def fixSpan(tbody):
        # expand colspan/rowspan so that every cell occupies a single grid slot
        #trs = tbody.findChildren('tr', recursive=False)
        trs = getTrs(tbody)
        ths_len = 0
        ths = list()
        trs_set = set()
        # expand columns first, then rows; otherwise the parsed table may become misaligned
        # iterate over every tr
        for indtr, tr in enumerate(trs):
            ths_tmp = tr.findChildren('th', recursive=False)
            # do not expand a tr that contains a nested table
            if len(tr.findChildren('table'))>0:
                continue
            if len(ths_tmp) > 0:
                ths_len = ths_len + len(ths_tmp)
                for th in ths_tmp:
                    ths.append(th)
                trs_set.add(tr)
            # iterate over the elements of this row
            tds = tr.findChildren(recursive=False)
            for indtd, td in enumerate(tds):
                # if colspan is set, duplicate the cell into the following positions of the same row
                if 'colspan' in td.attrs:
                    if str(re.sub("[^0-9]","",str(td['colspan'])))!="":
                        col = int(re.sub("[^0-9]","",str(td['colspan'])))
                        td['colspan'] = 1
                        for i in range(1, col, 1):
                            td.insert_after(copy.copy(td))
        for indtr, tr in enumerate(trs):
            ths_tmp = tr.findChildren('th', recursive=False)
            # do not expand a tr that contains a nested table
            if len(tr.findChildren('table'))>0:
                continue
            if len(ths_tmp) > 0:
                ths_len = ths_len + len(ths_tmp)
                for th in ths_tmp:
                    ths.append(th)
                trs_set.add(tr)
            # iterate over the elements of this row
            tds = tr.findChildren(recursive=False)
            for indtd, td in enumerate(tds):
                # if rowspan is set, duplicate the cell into the same position of the following rows
                if 'rowspan' in td.attrs:
                    if str(re.sub("[^0-9]","",str(td['rowspan'])))!="":
                        row = int(re.sub("[^0-9]","",str(td['rowspan'])))
                        td['rowspan'] = 1
                        for i in range(1, row, 1):
                            # get all tds of the next row and insert a copy at the matching position
                            if indtr+i<len(trs):
                                tds1 = trs[indtr + i].findChildren(['td','th'], recursive=False)
                                if len(tds1) >= (indtd) and len(tds1)>0:
                                    if indtd > 0:
                                        tds1[indtd - 1].insert_after(copy.copy(td))
                                    else:
                                        tds1[0].insert_before(copy.copy(td))
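
    # Illustrative note (not from the original file): after fixSpan, a cell such as
    # <td colspan="3">A</td> is replaced by three adjacent copies of the cell, and a
    # rowspan'd cell is copied into the rows below it, e.g. (rough sketch):
    #   before: | A (colspan=3) |          after: | A | A | A |
    #   before: | B (rowspan=2) | C |      after: | B | C |
    #           |               | D |             | B | D |
    # so getTable() below can read the table as a fully aligned grid.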

    def getTable(tbody):
        #trs = tbody.findChildren('tr', recursive=False)
        trs = getTrs(tbody)
        inner_table = []
        for tr in trs:
            tr_line = []
            tds = tr.findChildren(['td','th'], recursive=False)
            for td in tds:
                tr_line.append([re.sub('\s*','',td.get_text()),0])
            inner_table.append(tr_line)
        return inner_table

    # pad rows of unequal length so every row has the same width
    def fixTable(inner_table):
        maxWidth = 0
        for item in inner_table:
            if len(item)>maxWidth:
                maxWidth = len(item)
        for i in range(len(inner_table)):
            if len(inner_table[i])<maxWidth:
                for j in range(maxWidth-len(inner_table[i])):
                    inner_table[i].append(["",0])
        return inner_table
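
    # Illustrative note (not from the original file): inner_table is a list of rows,
    # each cell being a [text, flag] pair whose flag is later set by setHead():
    #   0 = value cell, 1 = row head, 2 = column head.  A 2x2 table such as
    #   | 名称  | 金额 |
    #   | A公司 | 100 |
    # becomes [[["名称",0],["金额",0]], [["A公司",0],["100",0]]] at this point.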

    # collapse repeated cell values (produced by fixSpan) into placeholder marks
    def removePadding(inner_table,pad_row = "@@",pad_col = "##"):
        height = len(inner_table)
        width = len(inner_table[0])
        for i in range(height):
            point = ""
            for j in range(width):
                if inner_table[i][j][0]==point and point!="":
                    inner_table[i][j][0] = pad_row
                else:
                    if inner_table[i][j][0] not in [pad_row,pad_col]:
                        point = inner_table[i][j][0]
        for j in range(width):
            point = ""
            for i in range(height):
                if inner_table[i][j][0]==point and point!="":
                    inner_table[i][j][0] = pad_col
                else:
                    if inner_table[i][j][0] not in [pad_row,pad_col]:
                        point = inner_table[i][j][0]

    # restore the placeholder marks back to the text (and flag) of the original cell
    def addPadding(inner_table,pad_row = "@@",pad_col = "##"):
        height = len(inner_table)
        width = len(inner_table[0])
        for i in range(height):
            for j in range(width):
                if inner_table[i][j][0]==pad_row:
                    inner_table[i][j][0] = inner_table[i][j-1][0]
                    inner_table[i][j][1] = inner_table[i][j-1][1]
                if inner_table[i][j][0]==pad_col:
                    inner_table[i][j][0] = inner_table[i-1][j][0]
                    inner_table[i][j][1] = inner_table[i-1][j][1]

    # mark the table heads with the form model and collect labelled cells into `data`
    def setHead(inner_table,prob_min=0.64):
        pad_row = "@@"
        pad_col = "##"
        removePadding(inner_table, pad_row, pad_col)
        pad_pattern = re.compile(pad_row+"|"+pad_col)
        height = len(inner_table)
        width = len(inner_table[0])
        head_list = []
        head_list.append(0)
        # row heads
        is_head_last = False
        for i in range(height):
            is_head = False
            is_long_value = False
            # check whether the whole row repeats a single (padding) value
            is_same_value = True
            same_value = inner_table[i][0][0]
            for j in range(width):
                if inner_table[i][j][0]!=same_value and inner_table[i][j][0]!=pad_row:
                    is_same_value = False
                    break
            #predict is head or not with model
            temp_item = ""
            for j in range(width):
                temp_item += inner_table[i][j][0]+"|"
            temp_item = re.sub(pad_pattern,"",temp_item)
            form_prob = formPredictor.predict(encoding(temp_item,expand=True))
            if form_prob is not None:
                if form_prob[0][1]>prob_min:
                    is_head = True
                else:
                    is_head = False
            #print(temp_item,form_prob)
            if len(inner_table[i][0][0])>40:
                is_long_value = True
            if is_head or is_long_value or is_same_value:
                # do not split consecutive head rows
                if not is_head_last:
                    head_list.append(i)
                if is_long_value or is_same_value:
                    head_list.append(i+1)
                if is_head:
                    for j in range(width):
                        if inner_table[i][j][0] not in data_set_is and inner_table[i][j][0] not in data_set_no:
                            data.append([file,inner_table[i][j][0],1])
                            data_set_is.add(inner_table[i][j][0])
                        inner_table[i][j][1] = 1
            is_head_last = is_head
        head_list.append(height)
        # column heads
        for i in range(len(head_list)-1):
            head_begin = head_list[i]
            head_end = head_list[i+1]
            # the last column is never marked as a column head
            for i in range(width-1):
                is_head = False
                #predict is head or not with model
                temp_item = ""
                for j in range(head_begin,head_end):
                    temp_item += inner_table[j][i][0]+"|"
                temp_item = re.sub(pad_pattern,"",temp_item)
                form_prob = formPredictor.predict(encoding(temp_item,expand=True))
                if form_prob is not None:
                    if form_prob[0][1]>prob_min:
                        is_head = True
                    else:
                        is_head = False
                if is_head:
                    for j in range(head_begin,head_end):
                        if inner_table[j][i][0] not in data_set_is and inner_table[j][i][0] not in data_set_no:
                            data.append([file,inner_table[j][i][0],1])
                            data_set_is.add(inner_table[j][i][0])
                        inner_table[j][i][1] = 2
        for line in inner_table:
            for item in line:
                if item[0] not in data_set_is and item[0] not in data_set_no:
                    data.append([file,item[0],0])
                    data_set_no.add(item[0])
        addPadding(inner_table, pad_row, pad_col)
        return inner_table,head_list
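
    # Illustrative note (not from the original file): setHead() returns the table with
    # every cell flagged (0 value / 1 row head / 2 column head) plus head_list, the row
    # indices that split the table into head-delimited blocks.  Along the way it appends
    # [file, cell_text, 1] for head cells and [file, cell_text, 0] for the remaining
    # cells to `data`, which is what getSourceData() later dumps into data_item.xls.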

    # mark the table heads with regex rules instead of the model
    def setHead_withRule(inner_table,pattern,pat_value,count):
        height = len(inner_table)
        width = len(inner_table[0])
        head_list = []
        head_list.append(0)
        # row heads
        is_head_last = False
        for i in range(height):
            set_match = set()
            is_head = False
            is_long_value = False
            is_same_value = True
            same_value = inner_table[i][0][0]
            for j in range(width):
                if inner_table[i][j][0]!=same_value:
                    is_same_value = False
                    break
            for j in range(width):
                if re.search(pat_value,inner_table[i][j][0]) is not None:
                    is_head = False
                    break
                str_find = re.findall(pattern,inner_table[i][j][0])
                if len(str_find)>0:
                    set_match.add(inner_table[i][j][0])
            if len(set_match)>=count:
                is_head = True
            if len(inner_table[i][0][0])>40:
                is_long_value = True
            if is_head or is_long_value or is_same_value:
                if not is_head_last:
                    head_list.append(i)
                if is_head:
                    for j in range(width):
                        inner_table[i][j][1] = 1
            is_head_last = is_head
        head_list.append(height)
        # column heads
        for i in range(len(head_list)-1):
            head_begin = head_list[i]
            head_end = head_list[i+1]
            # the last column is never marked as a column head
            for i in range(width-1):
                set_match = set()
                is_head = False
                for j in range(head_begin,head_end):
                    if re.search(pat_value,inner_table[j][i][0]) is not None:
                        is_head = False
                        break
                    str_find = re.findall(pattern,inner_table[j][i][0])
                    if len(str_find)>0:
                        set_match.add(inner_table[j][i][0])
                if len(set_match)>=count:
                    is_head = True
                if is_head:
                    for j in range(head_begin,head_end):
                        inner_table[j][i][1] = 2
        return inner_table,head_list

    # determine the reading direction of a table block
    def getDirect(inner_table,begin,end):
        column_head = set()
        row_head = set()
        widths = len(inner_table[0])
        for height in range(begin,end):
            for width in range(widths):
                if inner_table[height][width][1] ==1:
                    row_head.add(height)
                if inner_table[height][width][1] ==2:
                    column_head.add(width)
        company_pattern = re.compile("公司")
        if 0 in column_head and begin not in row_head:
            return "column"
        if 0 in column_head and begin in row_head:
            for height in range(begin,end):
                count = 0
                count_flag = True
                for width_index in range(width):
                    if inner_table[height][width_index][1]==0:
                        if re.search(company_pattern,inner_table[height][width_index][0]) is not None:
                            count += 1
                        else:
                            count_flag = False
                if count_flag and count>=2:
                    return "column"
        return "row"
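
    # Illustrative note (not from the original file): "row" means each row of the block
    # is read as one record (heads sit above / to the left of their values), while
    # "column" means each column is read as one record, e.g. a block whose first column
    # holds the field names and whose value rows consist mostly of company names.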

    # generate sentences from the table according to its reading direction
    def getTableText(inner_table,head_list):
        rankPattern = "(排名|排序|名次|评标结果|评审结果)"
        entityPattern = "(候选|([中投]标|报价)(人|单位|候选)|单位名称|供应商)"
        height = len(inner_table)
        width = len(inner_table[0])
        text = ""
        for head_i in range(len(head_list)-1):
            head_begin = head_list[head_i]
            head_end = head_list[head_i+1]
            direct = getDirect(inner_table, head_begin, head_end)
            if direct=="row":
                for i in range(head_begin,head_end):
                    rank_text = ""
                    entity_text = ""
                    text_line = ""
                    # duplicates within the same sentence can be dropped
                    text_set = set()
                    for j in range(width):
                        cell = inner_table[i][j]
                        # the cell is a value
                        if cell[1]==0:
                            find_flag = False
                            head = ""
                            temp_head = ""
                            for loop_j in range(1,j+1):
                                if inner_table[i][j-loop_j][1]==2:
                                    if find_flag:
                                        if inner_table[i][j-loop_j][0]!=temp_head:
                                            head = inner_table[i][j-loop_j][0]+":"+head
                                    else:
                                        head = inner_table[i][j-loop_j][0]+":"+head
                                    find_flag = True
                                    temp_head = inner_table[i][j-loop_j][0]
                                else:
                                    if find_flag:
                                        break
                            find_flag = False
                            temp_head = ""
                            for loop_i in range(0,i+1-head_begin):
                                if inner_table[i-loop_i][j][1]==1:
                                    if find_flag:
                                        if inner_table[i-loop_i][j][0]!=temp_head:
                                            head = inner_table[i-loop_i][j][0]+":"+head
                                    else:
                                        head = inner_table[i-loop_i][j][0]+":"+head
                                    find_flag = True
                                    temp_head = inner_table[i-loop_i][j][0]
                                else:
                                    # once a head has been found, stop at the first value cell
                                    if find_flag:
                                        break
                            if str(head+inner_table[i][j][0]) in text_set:
                                continue
                            if re.search(rankPattern,head) is not None:
                                rank_text += head+inner_table[i][j][0]+","
                                #print(rank_text)
                            elif re.search(entityPattern,head) is not None:
                                entity_text += head+inner_table[i][j][0]+","
                                #print(entity_text)
                            else:
                                text_line += head+inner_table[i][j][0]+","
                            text_set.add(str(head+inner_table[i][j][0]))
                    text += rank_text+entity_text+text_line
                    text = text[:-1]+"。"
            else:
                for j in range(width):
                    rank_text = ""
                    entity_text = ""
                    text_line = ""
                    text_set = set()
                    for i in range(head_begin,head_end):
                        cell = inner_table[i][j]
                        # the cell is a value
                        if cell[1]==0:
                            find_flag = False
                            head = ""
                            temp_head = ""
                            for loop_j in range(1,j+1):
                                if inner_table[i][j-loop_j][1]==2:
                                    if find_flag:
                                        if inner_table[i][j-loop_j][0]!=temp_head:
                                            head = inner_table[i][j-loop_j][0]+":"+head
                                    else:
                                        head = inner_table[i][j-loop_j][0]+":"+head
                                    find_flag = True
                                    temp_head = inner_table[i][j-loop_j][0]
                                else:
                                    if find_flag:
                                        break
                            find_flag = False
                            temp_head = ""
                            for loop_i in range(0,i+1-head_begin):
                                if inner_table[i-loop_i][j][1]==1:
                                    if find_flag:
                                        if inner_table[i-loop_i][j][0]!=temp_head:
                                            head = inner_table[i-loop_i][j][0]+":"+head
                                    else:
                                        head = inner_table[i-loop_i][j][0]+":"+head
                                    find_flag = True
                                    temp_head = inner_table[i-loop_i][j][0]
                                else:
                                    if find_flag:
                                        break
                            if str(head+inner_table[i][j][0]) in text_set:
                                continue
                            if re.search(rankPattern,head) is not None:
                                rank_text += head+inner_table[i][j][0]+","
                                #print(rank_text)
                            elif re.search(entityPattern,head) is not None:
                                entity_text += head+inner_table[i][j][0]+","
                                #print(entity_text)
                            else:
                                text_line += head+inner_table[i][j][0]+","
                            text_set.add(str(head+inner_table[i][j][0]))
                    text += rank_text+entity_text+text_line
                    text = text[:-1]+"。"
        return text
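
    # Illustrative example (not from the original file): for a row-direction block such as
    #   | 中标人 | 投标报价 |
    #   | A公司  | 100万   |
    # getTableText() concatenates "head:value" pairs per record, with rank/entity fields
    # promoted to the front, yielding roughly "中标人:A公司,投标报价:100万。".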

    def trunTable(tbody):
        fixSpan(tbody)
        inner_table = getTable(tbody)
        inner_table = fixTable(inner_table)
        if len(inner_table)>0 and len(inner_table[0])>0:
            #inner_table,head_list = setHead_withRule(inner_table,pat_head,pat_value,3)
            inner_table,head_list = setHead(inner_table)
            '''
            print("----")
            print(head_list)
            for item in inner_table:
                print(item)
            '''
            tbody.string = getTableText(inner_table,head_list)
            #print(tbody.string)
            tbody.name = "table"

    pat_head = re.compile('(名称|序号|项目|标项|工程|品目[一二三四1234]|第[一二三四1234](标段|名|候选人|中标)|包段|包号|货物|单位|数量|价格|报价|金额|总价|单价|[招投中]标|供应商|候选|编号|得分|评委|评分|名次|排名|排序|科室|方式|工期|时间|产品|开始|结束|联系|日期|面积|姓名|证号|备注|级别|地[点址]|类型|代理|制造)')
    #pat_head = re.compile('(名称|序号|项目|工程|品目[一二三四1234]|第[一二三四1234](标段|候选人|中标)|包段|包号|货物|单位|数量|价格|报价|金额|总价|单价|[招投中]标|供应商|候选|编号|得分|评委|评分|名次|排名|排序|科室|方式|工期|时间|产品|开始|结束|联系|日期|面积|姓名|证号|备注|级别|地[点址]|类型|代理)')
    pat_value = re.compile("(\d{2,}.\d{1}|\d+年\d+月|\d{8,}|\d{3,}-\d{6,}|有限[责任]*公司|^\d+$)")
    tbodies = soup.find_all('table')
    # iterate over every table element
    # process nested tables in reverse order (innermost first)
    for tbody_index in range(1,len(tbodies)+1):
        tbody = tbodies[len(tbodies)-tbody_index]
        trunTable(tbody)
    tbodies = soup.find_all('tbody')
    # iterate over every remaining tbody
    # process nested tables in reverse order (innermost first)
    for tbody_index in range(1,len(tbodies)+1):
        tbody = tbodies[len(tbodies)-tbody_index]
        trunTable(tbody)
    return soup
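
# Rough usage sketch (not from the original file; formPredictor must be instantiated
# for setHead() to work, see the commented-out line near the imports):
#   html = codecs.open("some_page.html", "r", encoding="utf8").read()
#   data, seen_head, seen_other = [], set(), set()
#   soup = tableToText(BeautifulSoup(html, "lxml"), data, "some_page.html", seen_head, seen_other)
#   print(soup.get_text())   # tables are now flattened into "head:value" sentences
#   print(data[:5])          # [filename, cell_text, 1/0] samples collected on the way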

def getSourceData():
    data = []
    data_set_is = set()
    data_set_no = set()
    for file in glob.glob("C:\\Users\\User\\Desktop\\20190320要素\\*.html"):
        filename = file.split("\\")[-1]
        source = codecs.open(file,"r",encoding="utf8").read()
        tableToText(BeautifulSoup(source,"lxml"),data,filename,data_set_is,data_set_no)
    for file in glob.glob("C:\\Users\\User\\Desktop\\20190306要素\\*.html"):
        filename = file.split("\\")[-1]
        source = codecs.open(file,"r",encoding="utf8").read()
        tableToText(BeautifulSoup(source,"lxml"),data,filename,data_set_is,data_set_no)
    list_file = []
    list_item = []
    list_label = []
    #data.sort(key=lambda x:x[2],reverse=True)
    data = data[0:60000]
    for item in data:
        list_file.append(item[0])
        list_item.append(item[1][:100])
        list_label.append(item[2])
    df = pd.DataFrame({"list_file":list_file,"list_item":list_item,"list_label":list_label})
    df.to_excel("data_item.xls",columns=["list_file","list_item","list_label"])

def importData():
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    file = "data_item.xls"
    df = pd.read_excel(file)
    for file,text,label in zip(df["list_file"],df["list_item"],df["list_label"]):
        text = str(text)
        text = text.replace("\\","\\\\")
        text = re.sub("'","\\'",str(text))
        sql = " insert into form(filename,text,label) values(E'"+file+"',E'"+str(text)+"',E'"+str(int(label))+"')"
        print(sql)
        cursor.execute(sql)
    conn.commit()
    conn.close()
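
# Note (not from the original file): the insert above builds SQL by string concatenation
# after manually escaping quotes and backslashes.  psycopg2 also supports parameterized
# execution, which avoids the escaping entirely, e.g. (rough sketch):
#   cursor.execute("insert into form(filename,text,label) values(%s,%s,%s)",
#                  (file, text, int(label)))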

def selectWithRule(source,filter,target):
    assert source!=target
    dict_source = pd.read_excel(source)
    set_filter = set()
    for filt in filter:
        set_filter = set_filter | set(pd.read_excel(filt)["list_item"])
    list_file = []
    list_item = []
    list_label = []
    for file,text,label in zip(dict_source["list_file"],dict_source["list_item"],dict_source["list_label"]):
        if str(text) in set_filter:
            continue
        if re.search(".{8,}(工程|项目|采购|公告|公示)",str(text)) is not None:
            #if len(str(text))>20:
            list_file.append(file)
            list_item.append(text)
            list_label.append(label)
    data = {"list_file":list_file,"list_item":list_item,"list_label":list_label}
    columns = ["list_file","list_item","list_label"]
    df = pd.DataFrame(data)
    df.to_excel(target,index=False,columns=columns)

def importRelabel():
    files = ["批量.xls"]
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    for file in files:
        df = pd.read_excel(file)
        for text,relabel in zip(df["list_item"],df["list_relabel"]):
            text = str(text)
            text = text.replace("\\","\\\\")
            text = re.sub("'","\\'",str(text))
            sql = " update form set relabel='"+str(int(relabel))+"' where text=E'"+str(text)+"' "
            cursor.execute(sql)
    conn.commit()
    conn.close()

def getHtml():
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    sql = " select filename from form where relabel is NULL group by filename having count(1)>0 "
    cursor.execute(sql)
    rows = cursor.fetchall()
    data = []
    index = 0
    for row in rows:
        filename = row[0]
        if filename=="比地_101_58519594.html":
            print(index)
        path = "C:\\Users\\User\\Desktop\\20190320要素\\"+filename
        if not os.path.exists(path):
            path = "C:\\Users\\User\\Desktop\\20190306要素\\"+filename
        data.append([filename,codecs.open(path,'r',encoding="utf8").read()])
        index += 1
    #save(data,"namehtml.pk")

def getTrainData(percent=0.9):
    conn = psycopg2.connect(dbname="article_label",user="postgres",password="postgres",host="192.168.2.101")
    cursor = conn.cursor()
    sql = "select filename,text,label,relabel,handlabel from form "
    cursor.execute(sql)
    rows = cursor.fetchall()
    save(rows,"filename_text_label_relabel_handlabel.pk")
    train_x = []
    train_y = []
    test_x = []
    test_y = []
    test_text = []
    for row in rows:
        input = str(row[1])
        label = str(int(row[2]))
        if row[4] is not None:
            label = str(int(row[4]))
        elif row[3] is not None:
            label = str(int(row[3]))
        item_y = [0,0]
        item_y[int(label)] = 1
        if np.random.random()<percent:
            # train_x.append(encodeInput(input))
            train_x.append(encodeInput([input], word_len=50, word_flag=True,userFool=False)[0])
            train_y.append(item_y)
        else:
            # test_x.append(encodeInput(input))
            test_x.append(encodeInput([input], word_len=50, word_flag=True,userFool=False)[0])
            test_y.append(item_y)
            test_text.append([row[0],input])
    return np.array(train_x),np.array(train_y),np.array(test_x),np.array(test_y),test_text
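
# Note (not from the original file): the effective label priority above is
# handlabel > relabel > label, the target is one-hot ([1,0] = not a head, [0,1] = head),
# and each row is randomly assigned to the train or test split with probability
# `percent` (default 0.9 train / 0.1 test).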

def getTrainData_jsonTable(begin,end,return_text=False):
    def encode_table(inner_table,size=30):
        def encode_item(_table,i,j):
            _x = [_table[j-1][i-1],_table[j-1][i],_table[j-1][i+1],
                  _table[j][i-1],_table[j][i],_table[j][i+1],
                  _table[j+1][i-1],_table[j+1][i],_table[j+1][i+1]]
            e_x = [encodeInput_form(_temp[0],MAX_LEN=30) for _temp in _x]
            _label = _table[j][i][1]
            # print(_x)
            # print(_x[4],_label)
            return e_x,_label,_x
        def copytable(inner_table):
            table = []
            for line in inner_table:
                list_line = []
                for item in line:
                    list_line.append([item[0][:size],item[1]])
                table.append(list_line)
            return table
        table = copytable(inner_table)
        padding = ["#"*30,0]
        width = len(table[0])
        height = len(table)
        table.insert(0,[padding for i in range(width)])
        table.append([padding for i in range(width)])
        for item in table:
            item.insert(0,padding.copy())
            item.append(padding.copy())
        data_x = []
        data_y = []
        data_text = []
        data_position = []
        for _i in range(1,width+1):
            for _j in range(1,height+1):
                _x,_y,_text = encode_item(table,_i,_j)
                data_x.append(_x)
                _label = [0,0]
                _label[_y] = 1
                data_y.append(_label)
                data_text.append(_text)
                data_position.append([_i-1,_j-1])
                # input = table[_j][_i][0]
                # item_y = [0,0]
                # item_y[table[_j][_i][1]] = 1
                # data_x.append(encodeInput([input], word_len=50, word_flag=True,userFool=False)[0])
                # data_y.append(item_y)
        return data_x,data_y,data_text,data_position

    def getDataSet(list_json_table,return_text=False):
        _count = 0
        _sum = len(list_json_table)
        data_x = []
        data_y = []
        data_text = []
        for json_table in list_json_table:
            _count += 1
            print("%d/%d"%(_count,_sum))
            table = json.loads(json_table)
            if table is not None:
                # encode_table also returns the cell positions; they are not used here
                list_x,list_y,list_text,list_position = encode_table(table)
                data_x.extend(list_x)
                data_y.extend(list_y)
                if return_text:
                    data_text.extend(list_text)
        return np.array(data_x),np.array(data_y),data_text

    save_path = "./traindata/websource_67000_table_%d-%d-%s.pk"%(begin,end,"1" if return_text else "0")
    if os.path.exists(save_path):
        data_x,data_y,data_text = load(save_path)
    else:
        df = pd.read_csv("../../dl_dev/form/traindata/websource_67000_table.csv", encoding="GBK")
        data_x,data_y,data_text = getDataSet(df["json_table"][begin:end],return_text=return_text)
        save((data_x,data_y,data_text),save_path)
    return data_x,data_y,data_text
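
# Illustrative note (not from the original file): encode_table() pads the table with a
# border of "#"*30 cells and then encodes, for every real cell, the 3x3 window of
# neighbouring cell texts (via encodeInput_form) together with the cell's head/value
# flag as a one-hot label, so the form model can use the surrounding context of a cell
# when deciding whether it is a table head.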

if __name__=="__main__":
    #getSourceData()
    #importData()
    #selectWithRule("data_item.xls", ["批量.xls"], "temp.xls")
    #importRelabel()
    # getHtml()
    # begin/end are required by getTrainData_jsonTable; the range below is only an
    # illustrative placeholder, adjust it to the slice of json tables to encode
    getTrainData_jsonTable(0, 1000)