check_img_labeled.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import base64
  2. import json
  3. import os
  4. import traceback
  5. import cv2
  6. import numpy as np
# Root directory that get_lines() walks recursively looking for label JSON files.
label_img_dir = "C:/Table_Label/"
  8. def get_lines(line_cnt=None):
  9. json_list = []
  10. for root, dirs, files in os.walk(label_img_dir, topdown=False):
  11. for name in files:
  12. if name.split(".")[-1] == 'json':
  13. json_list.append(root + os.path.sep + name)
  14. if line_cnt is not None:
  15. json_list = json_list[:line_cnt]
  16. print("len(json_list)", len(json_list))
  17. # 解析json
  18. lines_list = []
  19. for j in json_list:
  20. try:
  21. with open(j, "r") as f:
  22. _dict = json.loads(f.read())
  23. _list = _dict.get("shapes")
  24. line_list = []
  25. for d in _list:
  26. points = d.get("points")
  27. if len(points) < 2:
  28. continue
  29. if points[0][0] <= points[1][0]:
  30. line_list.append([round(points[0][0], 2), round(points[0][1], 2),
  31. round(points[1][0], 2), round(points[1][1], 2)])
  32. else:
  33. line_list.append([round(points[1][0], 2), round(points[1][1], 2),
  34. round(points[0][0], 2), round(points[0][1], 2)])
  35. line_list.sort(key=lambda x: (x[1], x[3], x[0], x[2]))
  36. lines_list.append([j, line_list])
  37. except Exception as e:
  38. traceback.print_exc()
  39. print("error path", j)
  40. continue
  41. return lines_list
  42. def get_angles(lines_list):
  43. angles_list = []
  44. for _path, lines in lines_list:
  45. angles = []
  46. for line in lines:
  47. x1, y1, x2, y2 = line
  48. if x1 == x2:
  49. k = None
  50. angle = 1.57
  51. else:
  52. k = abs((y2 - y1) / (x2 - x1))
  53. angle = np.arctan(k)
  54. angles.append([angle, line])
  55. angles_list.append([_path, angles])
  56. return angles_list
  57. def get_combine_lines(angles_list):
  58. def judge_axis_distance(l1, l2, d, axis=0):
  59. if axis:
  60. if abs(l1[1] - l2[1]) <= d:
  61. return True
  62. if abs(l1[3] - l2[1]) <= d:
  63. return True
  64. if abs(l1[1] - l2[3]) <= d:
  65. return True
  66. if abs(l1[3] - l2[3]) <= d:
  67. return True
  68. else:
  69. if abs(l1[0] - l2[0]) <= d:
  70. return True
  71. if abs(l1[2] - l2[0]) <= d:
  72. return True
  73. if abs(l1[0] - l2[2]) <= d:
  74. return True
  75. if abs(l1[2] - l2[2]) <= d:
  76. return True
  77. return False
  78. def judge_point_distance(l1, l2, d):
  79. for m in range(0, len(l1), 2):
  80. x1, y1 = l1[m], l1[m+1]
  81. for n in range(0, len(l2), 2):
  82. x2, y2 = l2[n], l2[n+1]
  83. if pow(pow(x1-x2, 2) + pow(y1-y2, 2), 0.5) <= d:
  84. return True
  85. angle_threshold = 8
  86. distance_threshold = 5
  87. combine_threshold = 20
  88. all_combined_line_list = []
  89. cnt = 0
  90. for _path, angles in angles_list:
  91. # if cnt % 10 == 0:
  92. # print("Loop", cnt)
  93. # cnt += 1
  94. # print("path", _path)
  95. # 获取需合并线
  96. one_image_equal_lines = []
  97. has_equal_lines = []
  98. for i in range(len(angles)):
  99. if angles[i] in has_equal_lines:
  100. continue
  101. equal_lines = []
  102. angle1 = angles[i][0]
  103. line1 = angles[i][1]
  104. row_or_col_1 = abs(line1[0] - line1[2]) - abs(line1[1] - line1[3])
  105. for j in range(i+1, len(angles)):
  106. if angles[i] in has_equal_lines:
  107. continue
  108. angle2 = angles[j][0]
  109. line2 = angles[j][1]
  110. row_or_col_2 = abs(line2[0] - line2[2]) - abs(line2[1] - line2[3])
  111. if row_or_col_1 * row_or_col_2 <= 0:
  112. continue
  113. # 判断角度相差
  114. if abs(angle1 - angle2) <= angle_threshold:
  115. # 判断距离
  116. if row_or_col_2 >= 0:
  117. if judge_axis_distance(line1, line2, combine_threshold, axis=1):
  118. if judge_point_distance(line1, line2, distance_threshold):
  119. equal_lines.append(angles[j])
  120. # has_equal_lines.append(angles[j])
  121. else:
  122. if judge_axis_distance(line1, line2, combine_threshold, axis=0):
  123. if judge_point_distance(line1, line2, distance_threshold):
  124. equal_lines.append(angles[j])
  125. # has_equal_lines.append(angles[j])
  126. equal_lines.append(angles[i])
  127. # has_equal_lines.append(angles[i])
  128. # print("equal_lines", equal_lines)
  129. one_image_equal_lines.append(equal_lines)
  130. # print("one_image_equal_lines", one_image_equal_lines)
  131. # 对组进行合并
  132. combined_group = []
  133. for group1 in one_image_equal_lines:
  134. for group2 in one_image_equal_lines:
  135. # print(group1, "-", group2)
  136. find_flag = False
  137. for line in group2:
  138. if line in group1:
  139. group1 += group2
  140. find_flag = True
  141. break
  142. if find_flag:
  143. break
  144. group1 = [str(x) for x in group1]
  145. group1 = list(set(group1))
  146. group1 = [eval(x) for x in group1]
  147. combined_group.append(group1)
  148. combined_group = [str(x) for x in combined_group]
  149. combined_group = list(set(combined_group))
  150. combined_group = [eval(x) for x in combined_group]
  151. # 对符合条件的线合并
  152. combined_line_list = []
  153. for equal_lines in combined_group:
  154. if len(equal_lines) > 1:
  155. x_plus_y_min = 10000
  156. x_plus_y_max = 0
  157. point_max = (0, 0)
  158. point_min = (0, 0)
  159. for angle, line in equal_lines:
  160. # print("angle, line", angle, line)
  161. if line[0] + line[1] <= x_plus_y_min:
  162. x_plus_y_min = line[0] + line[1]
  163. point_min = (line[0], line[1])
  164. if line[0] + line[1] > x_plus_y_max:
  165. x_plus_y_max = line[0] + line[1]
  166. point_max = (line[0], line[1])
  167. if line[2] + line[3] <= x_plus_y_min:
  168. x_plus_y_min = line[2] + line[3]
  169. point_min = (line[2], line[3])
  170. if line[2] + line[3] > x_plus_y_max:
  171. x_plus_y_max = line[2] + line[3]
  172. point_max = (line[2], line[3])
  173. combined_line = [point_min[0], point_min[1], point_max[0], point_max[1]]
  174. else:
  175. combined_line = equal_lines[0][1]
  176. combined_line_list.append(combined_line)
  177. combined_line_list = [str(x) for x in combined_line_list]
  178. combined_line_list = list(set(combined_line_list))
  179. combined_line_list = [eval(x) for x in combined_line_list]
  180. combined_line_list.sort(key=lambda x: (x[1], x[3], x[0], x[2]))
  181. all_combined_line_list.append([_path, combined_line_list])
  182. return all_combined_line_list
  183. def show_lines(lines_list):
  184. for _path, lines in lines_list:
  185. with open(_path, "r") as f:
  186. _dict = json.loads(f.read())
  187. img_bytes = base64.b64decode(_dict.get("imageData").encode("utf-8"))
  188. img_np = np.frombuffer(img_bytes, np.uint8)
  189. img_cv = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR)
  190. print("path", _path)
  191. print("img_cv.shape", img_cv.shape)
  192. cv2.namedWindow(_path, 0)
  193. cv2.resizeWindow(_path, 1000, 800)
  194. for line in lines:
  195. # 随机颜色
  196. color = np.random.randint(0, 255, 3, dtype=np.int32)
  197. color = (np.int(color[0]), np.int(color[1]), np.int(color[2]))
  198. cv2.line(img_cv,
  199. (int(line[0]), int(line[1])), (int(line[2]), int(line[3])),
  200. color, thickness=10)
  201. # print("show", line)
  202. cv2.imshow(_path, img_cv)
  203. cv2.waitKey(0)
  204. def to_json(lines_list):
  205. for _path, lines in lines_list:
  206. with open(_path, "r") as f:
  207. _dict = json.loads(f.read())
  208. new_shapes = []
  209. for line in lines:
  210. temp_dict = {"group_id": "null", "shape_type": "line", "flags": {}}
  211. temp_dict.update({"points": [[int(line[0]), int(line[1])], [int(line[2]), int(line[3])]]})
  212. if abs(line[0] - line[2]) >= abs(line[1] - line[3]):
  213. temp_dict.update({"label": "0"})
  214. else:
  215. temp_dict.update({"label": "1"})
  216. new_shapes.append(temp_dict)
  217. _dict["shapes"] = new_shapes
  218. new_path = "C:/Table_Label/clean_data/" + _path.split(os.sep)[-1]
  219. with open(new_path, "w") as f:
  220. f.write(json.dumps(_dict))
  221. def check():
  222. lines = get_lines()
  223. angles_list = get_angles(lines)
  224. # 分割数据处理
  225. batch = 100
  226. all_combined_lines = []
  227. for i in range(0, len(angles_list), batch):
  228. if i % 10 == 0:
  229. print("Loop", i)
  230. batch_list = angles_list[i:i+batch]
  231. combined_lines = get_combine_lines(batch_list)
  232. all_combined_lines += combined_lines
  233. to_json(all_combined_lines)
  234. # show_lines(combined_lines)
# Run the whole cleaning pipeline when the file is executed as a script.
if __name__ == '__main__':
    check()