train.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
'''
Created on 2019-08-12
@author: User
'''
  5. from module import model
  6. from module.Utils import *
  7. from keras.callbacks import ModelCheckpoint
  8. from keras import models
  9. import featureEngine
  10. import os
  11. def train():
  12. train_file = "source_12input_padding.pk"
  13. model1 = model.getBiRNNModel(input_shape=[None,12], out_len=2)
  14. data = load(train_file)
  15. train_percent = 0.9
  16. train_percent = 0.8
  17. test_percent=0.9
  18. print(np.shape(data[0]))
  19. train_len = round(len(data[0])*train_percent)
  20. test_len = round(len(data[0])*test_percent)
  21. callback = ModelCheckpoint("log/ep{epoch:03d}-acc{acc:.3f}-loss{loss:.3f}-val_acc{val_acc:.3f}-val_loss{val_loss:.3f}.h5",save_best_only=True, monitor="val_acc", verbose=1, mode="max")
  22. history_model = model1.fit(x=data[0][:train_len],y=data[1][:train_len],validation_data=[data[0][train_len:test_len],data[1][train_len:test_len]],epochs=100,batch_size=48,shuffle=True,callbacks=[callback])
  23. def predict(x):
  24. '''
  25. model1 = model.getBiRNNModel()
  26. model1.load_weights("../model_data/ep133-loss-0.991-val_acc0.972-val_loss-0.951-f10.3121.h5")
  27. '''
  28. path = "log/ep009-acc0.995-loss0.006-val_acc0.986-val_loss0.018.h5"
  29. model1 = models.load_model(path, custom_objects={"acc":acc,"precision":precision,"recall":recall,"f1_score":f1_score,"my_loss":my_loss})
  30. return model1.predict(x,batch_size=1)
  31. def val():
  32. pk_file = "source_12input_padding.pk"
  33. data = load(pk_file)
  34. train_percent = 0.9
  35. train_len = round(len(data[0])*train_percent)
  36. #print(np.shape(data))
  37. predict_y = np.argmax(predict(data[0][train_len:]),1)
  38. label_y = np.argmax(data[1][train_len:],1)
  39. list_url = data[2][train_len:]
  40. size_predict = 0
  41. size_considence = 0
  42. dict_root_true_wrong = dict()
  43. for _predict,_label,_url in zip(predict_y,label_y,list_url):
  44. root = _url.split("/")[2]
  45. if root not in dict_root_true_wrong:
  46. dict_root_true_wrong[root] = [0,0]
  47. if _predict[1]==_label[1]:
  48. size_considence += 1
  49. dict_root_true_wrong[root][0] += 1
  50. else:
  51. dict_root_true_wrong[root][1] += 1
  52. print(_url)
  53. size_predict += 1
  54. list_root_true_wrong = []
  55. for _key in dict_root_true_wrong.keys():
  56. list_root_true_wrong.append([_key,dict_root_true_wrong[_key]])
  57. list_root_true_wrong.sort(key=lambda x:x[1][1]/(x[1][0]+x[1][1]))
  58. print(list_root_true_wrong)
  59. print(size_considence,size_predict)
  60. def test(url):
  61. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  62. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  63. data = featureEngine.getInput_byJS(url)
  64. if data:
  65. x,list_inner,list_xpath,_ = data
  66. print("x:",x)
  67. p = predict(x)
  68. print(p)
  69. print(np.argmax(p,1))
  70. print(p[0][np.argmax(p,1)[0][1]])
  71. print(list_inner[np.argmax(p,1)[0][1]])
  72. print(list_xpath[np.argmax(p,1)[0][1]])
  73. if __name__=="__main__":
  74. #train()
  75. #val()
  76. test("http://www.gzmodern.cn/html/xydt/announcement/3429.html")