123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- import copy
- import random
- from glob import glob
- import cv2
- from sklearn.model_selection import train_test_split
- from image import gen
- from metrics import precision, recall, f1
- from table_line import model, focal_loss
- import numpy as np
# Seed Python's global RNG when the module is loaded, so that the
# random.sample call inside run() picks the same evaluation subset every time.
random.seed(a=42)
def run(is_test=1, paths_glob='C:/Table_Label/clean_data/*.json',
        model_path="./models/e17-f10.93.h5", sample_size=676, prob=0.2,
        show_image_from=6000):
    """Evaluate the table-line model on a random sample of label files.

    Loads ``model_path`` into the imported ``model``, feeds up to
    ``sample_size`` randomly chosen files matched by ``paths_glob`` through
    ``gen`` one batch at a time, prints per-sample precision / recall / F1
    and finally their averages over the samples actually evaluated.

    Args:
        is_test: when truthy, every prediction is visualised with
            ``draw_pixel`` (which blocks on ``plt.show()`` once per sample).
        paths_glob: glob pattern for the ``.json`` label files.
        model_path: weights file handed to ``model.load_weights``.
        sample_size: number of files to sample; capped at the number found.
        prob: threshold forwarded to ``draw_pixel``.
        show_image_from: sample index from which the input image is also
            shown in an OpenCV window (only when ``is_test``). The default
            (6000) exceeds the default sample size, so it normally never fires.
    """
    paths = glob(paths_glob)
    # random.sample raises ValueError when asked for more items than exist.
    paths = random.sample(paths, min(sample_size, len(paths)))
    data_len = len(paths)
    print("len(paths)", data_len)
    if not data_len:
        # Nothing to evaluate; also avoids dividing by zero below.
        return

    model.load_weights(model_path)
    data_loader = gen(paths, batchsize=1, linetype=2, is_test=True)

    p_score_total = r_score_total = f1_score_total = 0.0
    evaluated = 0
    for X, Y, img_np in data_loader:
        # Stop after one pass over the sampled paths; presumably gen keeps
        # yielding past the end (it is defined elsewhere) -- TODO confirm.
        if evaluated >= data_len:
            break
        pred = model.predict(X).astype(np.float64)
        if is_test:
            if evaluated >= show_image_from:
                cv2.namedWindow("1", 0)
                cv2.resizeWindow("1", 1000, 800)
                cv2.imshow("1", img_np)
                # HighGUI only repaints the window while it processes events.
                cv2.waitKey(1)
            draw_pixel(pred[0], prob, is_test)
        # Keep full precision for the running totals; round only for display.
        p_score = float(precision(Y, pred).numpy())
        r_score = float(recall(Y, pred).numpy())
        f1_score = float(f1(Y, pred).numpy())
        print("precision", round(p_score, 3), "recall", round(r_score, 3),
              "f1", round(f1_score, 3), "batch", evaluated, data_len)
        p_score_total += p_score
        r_score_total += r_score
        f1_score_total += f1_score
        evaluated += 1

    if not evaluated:
        # The loader produced no batches; there is nothing to average.
        return
    # Average over the batches actually evaluated, not the requested count.
    print("precision avg", p_score_total / evaluated)
    print("recall avg", r_score_total / evaluated)
    print("f1 avg", f1_score_total / evaluated)
def draw_pixel(pred, prob=0.2, is_test=1):
    """Plot a per-pixel colour map of a two-channel line prediction.

    A pixel whose channel 0 exceeds ``prob`` is drawn as ``(0, 0, 255)``.
    Otherwise, a pixel whose channel 1 exceeds ``prob`` is drawn as
    ``(255, 0, 0)``. All remaining pixels are white. Does nothing unless
    ``is_test`` is truthy.

    Args:
        pred: array-like indexed as ``pred[row][col][channel]`` with at least
            two channels.
        prob: threshold applied to both channels.
        is_test: gate for the (blocking) matplotlib display.
    """
    if not is_test:
        return
    # Imported lazily so that non-visual runs do not need matplotlib.
    import matplotlib.pyplot as plt

    image = _prediction_to_rgb(pred, prob)
    plt.axis('off')
    plt.imshow(image)
    plt.show()


def _prediction_to_rgb(pred, prob):
    """Return an (H, W, 3) int image colouring ``pred`` by threshold ``prob``."""
    pred = np.asarray(pred)
    first = pred[..., 0] > prob
    second = pred[..., 1] > prob
    image = np.full(first.shape + (3,), 255)
    # NOTE(review): matplotlib reads these triples as RGB, so channel 0 is
    # shown blue and channel 1 red; (0, 0, 255) looks like OpenCV BGR "red"
    # -- confirm the intended colours.
    image[second] = (255, 0, 0)
    # Assigned last so that channel 0 takes precedence, as in an if/elif chain.
    image[first] = (0, 0, 255)
    return image
if __name__ == "__main__":
    # Script entry point: run the evaluation with visualisation switched on.
    run(is_test=1)
|