train.py

import module.model as model
import featureEngine
from keras.callbacks import ModelCheckpoint
import numpy as np
import pickle
import os
from module.Utils import *
from keras import models
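
# Training and evaluation entry points for the BiRNN content-extraction model
# defined in module.model. Checkpoints are written to log/ during training.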

def save(object_to_save, path):
    '''
    Save an object to disk with pickle.
    @Args:
        object_to_save: the object to save
        path: the file path to write to
    @Return:
        None
    '''
    with open(path, 'wb') as f:
        pickle.dump(object_to_save, f)

def load(path):
    '''
    Load a pickled object from disk.
    @Args:
        path: the file path to read from
    @Return:
        the loaded object
    '''
    with open(path, 'rb') as f:
        object = pickle.load(f)
        return object
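
# Note: the cached data tuple appears to be (features, one-hot labels, URLs):
# data[0] is fed to the model, data[1] holds the one-hot node labels, and
# data[2] the source URLs (used by val() and iteratorLabel()).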

def train():
    pk_file = "iterator/data_28849_16.pk"
    if os.path.exists(pk_file):
        data = load(pk_file)
        #data = featureEngine.paddinig(data, pad=False)
        #data[1] = np.argmax(data[1],-1)
        #print(np.shape(data[0]))
    else:
        data = featureEngine.getAllData()
        save(data, "data_" + str(len(data[1])) + ".pk")
    model1 = model.getBiRNNModel()
    #model1.load_weights("../model_data/ep028-loss0.062-val_loss0.102-f10.9624.h5")
    model_file = "contentExtract.h5"
    log_dir = "log/"
    train_percent = 0.8
    test_percent = 0.9
    print(np.shape(data[0]))
    train_len = round(len(data[0]) * train_percent)
    test_len = round(len(data[0]) * test_percent)
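    # Split: the first 80% is used for training, the next 10% for validation;
    # the final 10% is held out (val() evaluates on it).
    # The checkpoint filename embeds epoch and metric values; with
    # monitor='loss', mode='min' and save_best_only=True, only epochs that
    # improve the training loss are kept.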
    checkpoint = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_acc{val_acc:.3f}-val_loss{val_loss:.3f}-f1{val_f1_score:.4f}.h5',
                                 monitor='loss', save_best_only=True, period=1, mode="min")
    history_model = model1.fit(x=data[0][:train_len], y=data[1][:train_len],
                               validation_data=(data[0][train_len:test_len], data[1][train_len:test_len]),
                               epochs=400, batch_size=256, shuffle=True, callbacks=[checkpoint])

def predict(x):
    '''
    model1 = model.getBiRNNModel()
    model1.load_weights("../model_data/ep133-loss-0.991-val_acc0.972-val_loss-0.951-f10.3121.h5")
    '''
    path = "log/ep011-loss0.160-val_acc0.900-val_loss0.156-f10.4536.h5"
    # the custom metrics/loss are provided by the wildcard import from module.Utils
    model1 = models.load_model(path, custom_objects={"acc": acc, "precision": precision, "recall": recall, "f1_score": f1_score, "my_loss": my_loss})
    return model1.predict(x, batch_size=1)
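
# Note: test() forces CPU execution by hiding all GPUs and, for a single URL,
# prints the model scores plus the inner text and XPath of the node it ranks
# highest for class 1 (presumably the main-content class).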

def test(url):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    data = featureEngine.getInput_byJS(url)
    if data is not None:
        x, list_inner, list_xpath, _ = data
        print("x:", x)
        p = predict(x)
        print(p)
        print(np.argmax(p, 1))
        # np.argmax(p, 1)[0][1] is the index of the node scored highest for class 1
        print(p[0][np.argmax(p, 1)[0][1]])
        print(list_inner[np.argmax(p, 1)[0][1]])
        #print(list_inner[4])
        print(list_xpath[np.argmax(p, 1)[0][1]])
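
# Note: val() evaluates the last 10% of the cached dataset and tallies, per
# site (the host part of each URL), how many predictions match the labels,
# then prints the sites sorted by error rate.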

def val():
    pk_file = "iterator/data_28849_16.pk"
    data = load(pk_file)
    train_percent = 0.9
    train_len = round(len(data[0]) * train_percent)
    #print(np.shape(data))
    predict_y = np.argmax(predict(data[0][train_len:]), 1)
    label_y = np.argmax(data[1][train_len:], 1)
    list_url = data[2][train_len:]
    size_predict = 0
    size_considence = 0  # number of predictions that agree with the label
    dict_root_true_wrong = dict()  # host -> [correct, wrong]
    for _predict, _label, _url in zip(predict_y, label_y, list_url):
        root = _url.split("/")[2]
        if root not in dict_root_true_wrong:
            dict_root_true_wrong[root] = [0, 0]
        if _predict[1] == _label[1]:
            size_considence += 1
            dict_root_true_wrong[root][0] += 1
        else:
            dict_root_true_wrong[root][1] += 1
            print(_url)
        size_predict += 1
    list_root_true_wrong = []
    for _key in dict_root_true_wrong.keys():
        list_root_true_wrong.append([_key, dict_root_true_wrong[_key]])
    # sort sites by error rate (wrong / total)
    list_root_true_wrong.sort(key=lambda x: x[1][1] / (x[1][0] + x[1][1]))
    print(list_root_true_wrong)
    print(size_considence, size_predict)
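
# Note: iteratorLabel() is a co-training style label-repair loop: two fresh
# models are trained on disjoint halves of the data, and whenever both agree
# on a label that differs from the stored one with confidence above
# `threshold`, that sample is relabelled. The loop stops once the set of
# newly changed URLs has stayed nearly unchanged (< 10 new items) for
# `max_not_change_times` consecutive rounds, then the repaired dataset is saved.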

def iteratorLabel():
    '''
    @summary: iteratively repair the dataset labels
    '''
    data_file = "iterator/data_28849_35.pk"
    threshold = 0.93
    train_epochs = 10
    batch_size = 96
    data = load(data_file)
    data_split = round(len(data[1]) * 0.5)
    last_change_set = set()
    this_change_set = set()
    max_not_change_times = 10
    _not_change_times = 0
    _time = 0
    while True:
        _time += 1
        # train one model on each half of the data
        model_1 = model.getBiRNNModel()
        model_2 = model.getBiRNNModel()
        model_1.fit(x=data[0][:data_split], y=data[1][:data_split], epochs=train_epochs, batch_size=batch_size, shuffle=True)
        model_2.fit(x=data[0][data_split:], y=data[1][data_split:], epochs=train_epochs, batch_size=batch_size, shuffle=True)
        predict_1 = model_1.predict(data[0])
        predict_2 = model_2.predict(data[0])
        _index = 0
        for _max_1, _max_2, _y1, _y2, Y, _url in zip(np.max(predict_1, 1), np.max(predict_2, 1), np.argmax(predict_1, 1), np.argmax(predict_2, 1), np.argmax(data[1], 1), data[2]):
            if _y1[1] == _y2[1] and _y1[1] != Y[1] and _max_1[1] > threshold and _max_2[1] > threshold:
                # both models agree on a different label with high confidence: fix the label
                data[1][_index][Y[1]] = 0
                data[1][_index][_y1[1]] = 1
                this_change_set.add(_url)
            _index += 1
        if len(this_change_set - last_change_set) < 10:
            _not_change_times += 1
        else:
            _not_change_times = 0
        if _not_change_times >= max_not_change_times:
            break
        last_change_set = this_change_set
        this_change_set = set()
    save(data, "iterator/data_" + str(len(data[1])) + "_1.pk")

if __name__ == "__main__":
    #train()
    test(url="https://www.600757.com.cn/show-106-14208-1.html")
    #val()
    #print(2248/2555)
    #iteratorLabel()
    pass