main.py

# encoding=utf-8
import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force CPU-only execution
import pickle
import pandas as pd
import tensorflow as tf
from text_classifier_pai.data_util import precision, recall, f1_score, get_remove_word, get_embedding, get_label, get_dic, clean_word_with_tokenizer, model_in
# from data_util import precision, recall, f1_score, get_remove_word, get_embedding, get_label, get_dic, clean_word_with_tokenizer, model_in
import keras.backend as K
from keras.layers import Input, Embedding, Bidirectional, GRU, Dropout, Dense, Concatenate, Lambda, LSTM
from keras.models import Model
# from keras import models, metrics
from keras.callbacks import ModelCheckpoint
from keras.engine.topology import Layer
from keras.optimizers import Adam, SGD


class Attention(Layer):
    def __init__(self, **kwargs):
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # W: (EMBED_SIZE, 1)
        # b: (MAX_TIMESTEPS, 1)
        # u: (MAX_TIMESTEPS, MAX_TIMESTEPS)
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros")
        self.u = self.add_weight(name="u_{:s}".format(self.name),
                                 shape=(input_shape[1], input_shape[1]),
                                 initializer="normal")
        super(Attention, self).build(input_shape)

    def call(self, x, mask=None):
        # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # et: (BATCH_SIZE, MAX_TIMESTEPS)
        et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=-1)
        # at: (BATCH_SIZE, MAX_TIMESTEPS)
        at = K.dot(et, self.u)
        at = K.exp(at)
        if mask is not None:
            at *= K.cast(mask, K.floatx())
        # normalize the attention weights over the time axis
        at /= K.cast(K.sum(at, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        atx = K.expand_dims(at, axis=-1)
        ot = atx * x
        # output: (BATCH_SIZE, EMBED_SIZE)
        return K.sum(ot, axis=1)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def compute_output_shape(self, input_shape):
        # output shape: (BATCH_SIZE, EMBED_SIZE)
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        return super(Attention, self).get_config()
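

# Illustrative sketch only (not part of the original pipeline): a quick shape check for
# the Attention layer defined above. The timestep/embedding sizes (150, 100) and the
# random batch are assumed values for demonstration; this helper is never called.
def _attention_shape_check():
    import numpy as np
    demo_in = Input(shape=(150, 100), dtype="float32")
    demo_out = Attention()(demo_in)  # weighted sum over timesteps -> (batch, 100)
    demo_model = Model(inputs=demo_in, outputs=demo_out)
    pooled = demo_model.predict(np.random.rand(2, 150, 100))
    assert pooled.shape == (2, 100)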


class Text_Classifier():
    def __init__(self):
        self.remove_word = get_remove_word()  # load stop words and other unimportant words
        self.word_index, self.tokenizer, self.embedding_matrix = get_embedding()  # load files; returns the vocabulary, the Keras tokenizer object and the embedding matrix
        self.label_mapping, self.labels = get_label()  # load the label dictionary; returns label_mapping, e.g. {0: '安防系统', 1: '安全保护服务', 2: '安全保护设备', ...}, and labels, the Chinese names of all classes
        self.dic_label = get_dic()  # load the category hierarchy (major and intermediate classes)
        # self.checkpoint = '/home/python/lishimin/linuxPro/text_classifier_project/model/New_attentionGUR_embed100_relabel0311.h5'
        self.graph = tf.get_default_graph()
        self.model = self.bigru_attention_softmax(150, self.word_index, self.embedding_matrix, classes=203)
        # self.model.load_weights(self.checkpoint)
        self.model.load_weights(os.path.dirname(__file__) + '/pickle_1/AttentionGRUacc0.9_class203.model')

    def bigru_attention_softmax(self, input_size, word_index, embedding_matrix, classes):
        sent_inputs = Input(shape=(input_size,), dtype="float32")
        sent_emb = Embedding(input_dim=len(word_index) + 1,
                             output_dim=100,
                             mask_zero=True,
                             weights=[embedding_matrix])(sent_inputs)
        sent_enc = Bidirectional(GRU(512, dropout=0.5, recurrent_dropout=0.5,
                                     return_sequences=True))(sent_emb)
        embeddings = Dropout(0.5)(sent_enc)
        sent_att1 = Attention()(embeddings)
        fc2_dropout = Dropout(0.5)(sent_att1)
        # fc1 = Dense(1024, activation="relu")(fc1_dropout)
        # fc2_dropout = Dropout(0.5)(fc1)
        sent_pred = Dense(classes, activation="softmax")(fc2_dropout)
        model = Model(inputs=sent_inputs, outputs=sent_pred)
        # model.summary()
        return model
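
    # Illustrative sketch only (an assumption, not code from the original file): the
    # otherwise-unused imports above (Adam, ModelCheckpoint, precision/recall/f1_score)
    # suggest a training step along these lines. The learning rate, batch size, epoch
    # count and checkpoint path are placeholders.
    def train(self, x_train, y_train, x_val, y_val, checkpoint_path='attention_gru_weights.h5'):
        self.model.compile(optimizer=Adam(lr=1e-3),
                           loss='categorical_crossentropy',
                           metrics=['accuracy', precision, recall, f1_score])
        saver = ModelCheckpoint(checkpoint_path, monitor='val_acc',
                                save_best_only=True, save_weights_only=True)
        self.model.fit(x_train, y_train,
                       validation_data=(x_val, y_val),
                       batch_size=64, epochs=10, callbacks=[saver])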

    def process(self, text_list):
        ContentIDs = [[i, text] for i, text in enumerate(text_list)]
        features, ids = clean_word_with_tokenizer(ContentIDs, self.remove_word, self.tokenizer)
        return features, ids

    def predict(self, features, ids):
        with self.graph.as_default():
            logits = self.model.predict(features)
        return logits, ids

    def get_results(self, logits, ids):
        return model_in(logits, self.label_mapping, ids)


if __name__ == '__main__':
    file = '/data/python/lsm/test_11_relabel_0304.csv'  # data relabeled on 2020-03-04
    # file = '/home/python/projects_deeplearning/TextSplit/test_11.csv'  # 耔录 original annotated data
    df = pd.read_csv(file)
    text_list = list(df['file'])
    classifier = Text_Classifier()
    features, ids = classifier.process([text_list[843]])
    logits, ids = classifier.predict(features, ids)
    results = classifier.get_results(logits, ids)
    print(results)
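
    # Illustrative sketch only (an assumption): the same pipeline applied to every row of
    # the CSV rather than the single sample above. The structure returned by get_results /
    # model_in is not shown in this file, so writing predictions back to the dataframe is
    # left commented out with a placeholder column name and output path.
    # features, ids = classifier.process(text_list)
    # logits, ids = classifier.predict(features, ids)
    # results = classifier.get_results(logits, ids)
    # df['pred_label'] = results  # placeholder: depends on what model_in returns
    # df.to_csv('predictions.csv', index=False)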