# inference.py
import copy
import random
from glob import glob

import cv2
import numpy as np
from sklearn.model_selection import train_test_split

from image import gen
from metrics import precision, recall, f1
from table_line import model, focal_loss

# Fixed seed so that random sampling of label files is reproducible.
random.seed(42)
  11. def run(is_test=1):
  12. # paths = glob('train/dataset-line/6/*.json')
  13. paths = glob('C:/Table_Label/clean_data/*.json')
  14. paths = random.sample(paths, 676)
  15. print("len(paths)", len(paths))
  16. model_path = "./models/e17-f10.93.h5"
  17. # model_path = "./models/table-line_old.h5"
  18. model.load_weights(model_path)
  19. data_loader = gen(paths, batchsize=1, linetype=2, is_test=True)
  20. i = 0
  21. data_len = len(paths)
  22. p_score_total = 0
  23. r_score_total = 0
  24. f1_score_total = 0
  25. for X, Y, img_np in data_loader:
  26. if i >= len(paths):
  27. break
  28. pred = model.predict(X)
  29. pred = pred.astype(np.float64)
  30. if is_test:
  31. if i >= 6000:
  32. cv2.namedWindow("1", 0)
  33. cv2.resizeWindow("1", 1000, 800)
  34. cv2.imshow("1", img_np)
  35. draw_pixel(pred[0], 0.2, is_test)
  36. p_score = round(precision(Y, pred).numpy(), 3)
  37. r_score = round(recall(Y, pred).numpy(), 3)
  38. f1_score = round(f1(Y, pred).numpy(), 3)
  39. print("precision", p_score, "recall", r_score, "f1", f1_score, "batch", i, data_len)
  40. p_score_total += p_score
  41. r_score_total += r_score
  42. f1_score_total += f1_score
  43. i += 1
  44. print("precision avg", p_score_total/data_len)
  45. print("recall avg", r_score_total/data_len)
  46. print("f1 avg", f1_score_total/data_len)
  47. def draw_pixel(pred, prob=0.2, is_test=1):
  48. if not is_test:
  49. return
  50. else:
  51. import matplotlib.pyplot as plt
  52. _array = []
  53. for _h in range(len(pred)):
  54. _line = []
  55. for _w in range(len(pred[_h])):
  56. _prob = pred[_h][_w]
  57. if _prob[0] > prob:
  58. _line.append((0, 0, 255))
  59. elif _prob[1] > prob:
  60. _line.append((255, 0, 0))
  61. else:
  62. _line.append((255, 255, 255))
  63. _array.append(_line)
  64. plt.axis('off')
  65. plt.imshow(np.array(_array))
  66. plt.show()
  67. return
  68. if __name__ == "__main__":
  69. run()