pre_process.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import sys
  2. import os
  3. sys.path.append(os.path.abspath("../.."))
  4. import psycopg2
  5. import numpy as np
  6. def get_sentence_index_list(sentence, dict_path='utils/ppocr_keys_v1.txt'):
  7. with open(dict_path, 'r') as f:
  8. character_list = f.readlines()
  9. for i in range(len(character_list)):
  10. character_list[i] = character_list[i][:-1]
  11. index_list = []
  12. for character in sentence:
  13. if character == '':
  14. index_list.append(0)
  15. elif character in character_list:
  16. _index = character_list.index(character) + 1
  17. index_list.append(_index)
  18. else:
  19. index_list.append(0)
  20. return index_list
  21. def postgresql_util(sql, limit):
  22. conn = psycopg2.connect(dbname="table_head_label", user="postgres", password="postgres",
  23. host="192.168.2.103")
  24. cursor = conn.cursor()
  25. cursor.execute(sql)
  26. print(sql)
  27. rows = cursor.fetchmany(1000)
  28. cnt = 0
  29. all_rows = []
  30. while rows:
  31. if cnt >= limit:
  32. break
  33. all_rows += rows
  34. cnt += len(rows)
  35. rows = cursor.fetchmany(1000)
  36. return all_rows
  37. def get_data_from_sql(dim=10):
  38. sql = """
  39. select table_text, pre_label, post_label, id
  40. from label_table_head_info
  41. where update_user <> 'test27' and table_box_cnt >= 4 and table_box_cnt <= 200
  42. limit 1000;
  43. """
  44. # sql = """
  45. # select table_text, pre_label, post_label, id
  46. # from label_table_head_info
  47. # where id = 843
  48. # """
  49. result_list = postgresql_util(sql, limit=10000)
  50. all_data_list = []
  51. all_data_label_list = []
  52. i = 0
  53. # 一行就是一篇表格
  54. for table in result_list:
  55. i += 1
  56. if i % 100 == 0:
  57. print("Loop", i)
  58. pre_label = eval(table[1])
  59. post_label = eval(table[2])
  60. _id = table[3]
  61. # table_text需要特殊处理
  62. try:
  63. table_text = table[0]
  64. if table_text[0] == '"':
  65. table_text = eval(table_text)
  66. else:
  67. table_text = table_text
  68. table_text = table_text.replace('\\', '/')
  69. table_text = eval(table_text)
  70. except:
  71. print("无法识别table_text", _id)
  72. continue
  73. # 只有一行的也不要
  74. if len(post_label) >= 2:
  75. data_list, data_label_list = table_process(table_text, post_label, _id)
  76. elif len(pre_label) >= 2:
  77. data_list, data_label_list = table_process(table_text, pre_label, _id)
  78. else:
  79. data_list, data_label_list = [], []
  80. for data in data_list:
  81. # 中文字符映射为index
  82. data[0] = get_sentence_index_list(data[0])
  83. data[1] = get_sentence_index_list(data[1])
  84. # 维度不够,填充掩码0
  85. if len(data[0]) < dim:
  86. data[0] = data[0] + [0]*(dim-len(data[0]))
  87. elif len(data[0]) > dim:
  88. data[0] = data[0][:dim]
  89. if len(data[1]) < dim:
  90. data[1] = data[1] + [0]*(dim-len(data[1]))
  91. elif len(data[1]) > dim:
  92. data[1] = data[1][:dim]
  93. all_data_list += data_list
  94. all_data_label_list += data_label_list
  95. return all_data_list, all_data_label_list
  96. def table_process(text_list, label_list, _id):
  97. if len(text_list) != len(label_list):
  98. print("文字单元格与标注单元格数量不匹配!", _id)
  99. print("len(text_list)", len(text_list), "len(label_list)", len(label_list))
  100. return [], []
  101. data_list = []
  102. data_label_list = []
  103. for i in range(len(text_list)):
  104. row = text_list[i]
  105. row_label = label_list[i]
  106. if i < len(text_list) - 1:
  107. next_row = text_list[i+1]
  108. next_row_label = label_list[i+1]
  109. else:
  110. next_row = []
  111. next_row_label = []
  112. for j in range(len(row)):
  113. col = row[j]
  114. col_label = row_label[j]
  115. if j < len(row) - 1:
  116. next_col = row[j+1]
  117. next_col_label = row_label[j+1]
  118. else:
  119. next_col = ""
  120. next_col_label = ""
  121. if next_row:
  122. next_row_col = next_row[j]
  123. next_row_col_label = next_row_label[j]
  124. else:
  125. next_row_col = ""
  126. next_row_col_label = ""
  127. if next_col:
  128. if col != next_col:
  129. data_list.append([col, next_col])
  130. data_label_list.append([int(col_label), int(next_col_label)])
  131. if next_row_col:
  132. if col != next_row_col:
  133. data_list.append([col, next_row_col])
  134. data_label_list.append([int(col_label), int(next_row_col_label)])
  135. return data_list, data_label_list
  136. def get_data_from_file():
  137. data_path = 'train_data/data.txt'
  138. data_label_path = 'train_data/data_label.txt'
  139. with open(data_path, 'r') as f:
  140. data_list = f.readlines()
  141. with open(data_label_path, 'r') as f:
  142. data_label_list = f.readlines()
  143. for i in range(len(data_list)):
  144. data_list[i] = eval(data_list[i][:-1])
  145. data_label_list[i] = eval(data_label_list[i][:-1])
  146. print(len(data_list))
  147. return data_list, data_label_list
  148. def processed_save_to_txt():
  149. list1, list2 = get_data_from_sql()
  150. with open('train_data/data.txt', 'w') as f:
  151. for line in list1:
  152. f.write(str(line) + "\n")
  153. with open('train_data/data_label.txt', 'w') as f:
  154. for line in list2:
  155. f.write(str(line) + "\n")
  156. if __name__ == '__main__':
  157. get_data_from_file()