predictor.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import os
  2. from keras import models
  3. import numpy as np
  4. from module.Utils import *
  5. current_path = os.path.dirname(__file__)
  6. class ListpageContentPredictor():
  7. def __init__(self,file=""):
  8. if file=="":
  9. self.model_file = current_path+"/listpage/content/model/ep005-acc0.970-loss0.047-val_acc0.944-val_loss0.077.h5"
  10. else:
  11. self.model_file = current_path+"/listpage/content/model/"+file
  12. self.model = None
  13. self.getModel()
  14. self.graph = tf.get_default_graph()
  15. def getModel(self):
  16. if self.model is None:
  17. self.model = models.load_model(self.model_file, custom_objects={"acc":acc,"precision":precision,"recall":recall,"f1_score":f1_score,"my_loss":my_loss})
  18. self.model.load_weights(self.model_file)
  19. return self.model
  20. def predict(self,x):
  21. with self.graph.as_default():
  22. pre= self.getModel().predict(x)
  23. max_index = np.argmax(pre,1)[0][1]
  24. return max_index
  25. class DetailContentPredictor():
  26. def __init__(self,file=""):
  27. if file=="":
  28. self.model_file = current_path+"/detail/content/model/ep011-loss0.160-val_acc0.900-val_loss0.156-f10.4536.h5"
  29. else:
  30. self.model_file = current_path+"/detail/content/model/"+file
  31. self.model = None
  32. self.getModel()
  33. self.graph = tf.get_default_graph()
  34. def getModel(self):
  35. if self.model is None:
  36. self.model = models.load_model(self.model_file, custom_objects={"acc":acc,"precision":precision,"recall":recall,"f1_score":f1_score,"my_loss":my_loss})
  37. self.model.load_weights(self.model_file)
  38. return self.model
  39. def predict(self,x):
  40. with self.graph.as_default():
  41. pre= self.getModel().predict(x)
  42. max_index = np.argmax(pre,1)[0][1]
  43. return max_index
  44. class DetailTitlePredictor():
  45. def __init__(self,file=""):
  46. if file=="":
  47. self.model_file = current_path+"/detail/title/model/ep009-acc0.995-loss0.006-val_acc0.986-val_loss0.018.h5"
  48. else:
  49. self.model_file = current_path+"/detail/title/model/"+file
  50. self.model = None
  51. self.getModel()
  52. self.graph = tf.get_default_graph()
  53. def getModel(self):
  54. if self.model is None:
  55. self.model = models.load_model(self.model_file, custom_objects={"acc":acc,"precision":precision,"recall":recall,"f1_score":f1_score,"my_loss":my_loss})
  56. self.model.load_weights(self.model_file)
  57. return self.model
  58. def predict(self,x):
  59. with self.graph.as_default():
  60. pre= self.getModel().predict(x)
  61. max_index = np.argmax(pre,1)[0][1]
  62. return max_index