#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:03
from BiddingKG.dl.product.product_model import Product_Model
from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json, df2data, dfsearchlb
from BiddingKG.dl.product.data_process import data_precess
import numpy as np
import pandas as pd
import tensorflow as tf
import random
import pickle
import os
import glob

# Run on CPU only: hide all CUDA devices from TensorFlow.
os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
def train():
    # all_data = get_label_data()
    # random.shuffle(all_data)
    # train_data = all_data[:int(len(all_data)*0.85)]
    # dev_data = all_data[int(len(all_data)*0.85):]

    # df = pd.read_excel('data/所有产品标注数据筛选20211125.xlsx')
    # df.reset_index(drop=True, inplace=True)
    # np.random.seed(8)
    # shuffle_ids = np.random.permutation(len(df))
    # split_ids = int(len(df)*0.1)
    # train_ids = shuffle_ids[split_ids:]
    # dev_ids = shuffle_ids[:int(split_ids/2)]
    # df_train = df.iloc[train_ids]
    # df_dev = df.iloc[dev_ids]
    # train_data = df2data(df_train)
    # dev_data = df2data(df_dev)

    # with open(os.path.dirname(__file__)+'/data/train_data2021-11-30.pkl', 'rb') as f:
    #     train_data = pickle.load(f)
    # with open(os.path.dirname(__file__)+'/data/dev_data2021-11-30.pkl', 'rb') as f:
    #     dev_data = pickle.load(f)

    train_data, dev_data = data_precess()
    train_manager = BatchManager(train_data, batch_size=256)
    dev_manager = BatchManager(dev_data, batch_size=256)

    # tf_config = tf.ConfigProto()
    # tf_config.gpu_options.allow_growth = True
    # Note: device_count keys are case-sensitive ('GPU', not 'gpu'); with
    # CUDA_VISIBLE_DEVICES="-1" set above, training runs on CPU regardless.
    tf_config = tf.ConfigProto(device_count={'GPU': 1})
    steps_per_epoch = train_manager.len_data  # number of batches per epoch
    ckpt_path = os.path.dirname(__file__) + '/' + "model"
    with tf.Session(config=tf_config) as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, os.path.join(ckpt_path, "ner2.ckpt"))
        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print("Loaded existing model weights from", ckpt.model_checkpoint_path)
        print('Preparing training data')
        loss = []
        min_loss = 1000
        max_f1 = 0
        for i in range(20):
            print('epochs:', i)
            # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
            # break
            for batch in train_manager.iter_batch(shuffle=True):
                # print('batch:', len(batch))
                # step, batch_loss = model.run_step(sess, True, batch)
                step, batch_loss = model.run_step(sess, 'train', batch)
                loss.append(batch_loss)
                if step % 1000 == 0:
                    iteration = step // steps_per_epoch + 1
                    print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss)))
            if i >= 2 or i % 5 == 0:
                f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
                print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
                # if max_f1 < f1:
                #     model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
                #     print("model saved. best f1 is %.4f" % f1)
                #     max_f1 = f1
                # Save only when both the dev loss and the dev f1 improve.
                if evl_loss < min_loss and max_f1 < f1:
                    min_loss = evl_loss
                    max_f1 = f1
                    model.saver.save(sess, os.path.join(ckpt_path, "ner1202_find_lb.ckpt"))  # ner1130_find_lb.ckpt
                    print("model saved, val_loss is:", min_loss)
            loss = []
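
# A minimal sketch of the BatchManager interface that train() relies on,
# inferred only from how it is called above (len_data, iter_batch(shuffle=...)).
# The real implementation lives in BiddingKG.dl.product.data_util; this
# stand-in is illustrative only and is named differently on purpose so it
# cannot shadow the imported class.
class _BatchManagerSketch:
    def __init__(self, data, batch_size):
        # Split `data` into consecutive chunks of `batch_size` examples.
        self.batch_data = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
        self.len_data = len(self.batch_data)  # batches per epoch (steps_per_epoch above)

    def iter_batch(self, shuffle=False):
        # Optionally reshuffle batch order each epoch, then yield one batch at a time.
        if shuffle:
            random.shuffle(self.batch_data)
        for batch in self.batch_data:
            yield batch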
def evaluate_line():
    ckpt_path = "model"
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        # model.saver.restore(sess, 'model/ner1215.ckpt')
        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')
        model.saver.restore(sess, 'model/ner_epoch10_f10.6875_loss1.5230.ckpt')
        while True:
            line = input("Enter a test sentence: ")
            result = model.evaluate_line(sess, line)
            print(result)
        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt_path):
        #     print('Model file:', ckpt.model_checkpoint_path)
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
        #     while True:
        #         line = input("Enter a test sentence: ")
        #         result = model.evaluate_line(sess, line)
        #         print(result)
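
# predict() below reads a frozen graph from model/product.pb, but the file is
# not produced anywhere in this script. The sketch below shows one plausible
# way to export it from a trained checkpoint with the standard TF1 freezing
# API; the output node names are taken from the tensors fetched in predict(),
# while the checkpoint path is a hypothetical placeholder.
def export_frozen_graph(ckpt_file='model/ner2.ckpt', pb_path='model/product.pb'):
    with tf.Graph().as_default(), tf.Session() as sess:
        model = Product_Model()  # rebuild the graph, then load trained weights into it
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, ckpt_file)
        # Bake variables into constants so the graph is self-contained.
        frozen = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def,
            output_node_names=['Sum', 'logits/Reshape', 'crf_loss/transitions'])
        with tf.gfile.GFile(pb_path, 'wb') as f:
            f.write(frozen.SerializeToString())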
def predict():
    pb_path = "model/product.pb"
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')  # note: do not pass a name prefix here, or the tensor names below will not resolve
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for node in output_graph_def.node:
                print(node.name)
            char_input = sess.graph.get_tensor_by_name("CharInputs:0")
            length = sess.graph.get_tensor_by_name("Sum:0")
            dropout = sess.graph.get_tensor_by_name("Dropout:0")
            logit = sess.graph.get_tensor_by_name("logits/Reshape:0")
            tran = sess.graph.get_tensor_by_name("crf_loss/transitions:0")
            while True:
                line = input("Enter a test sentence: ")
                _, chars, tags = input_from_line(line)
                print(chars)
                lengths, scores, tran_ = sess.run([length, logit, tran],
                                                  feed_dict={char_input: np.asarray(chars),
                                                             dropout: 1.0})
                batch_paths = decode(scores, lengths, tran_)
                tags = batch_paths[0]  # batch_paths[0][:lengths] would be wrong: lengths is not a scalar here
                result = result_to_json(line, tags)
                print(result)
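
# decode() is imported from data_util and not defined in this file. A minimal
# sketch of what it presumably does, assuming per-sequence Viterbi decoding
# over the CRF emission scores and transition matrix (tf.contrib.crf's
# viterbi_decode is the standard TF1 helper); the function name here is
# hypothetical.
def _decode_sketch(logits, lengths, trans):
    from tensorflow.contrib.crf import viterbi_decode
    paths = []
    for score, length in zip(logits, lengths):
        # Restrict each sequence to its true length before decoding.
        path, _ = viterbi_decode(score[:length], trans)
        paths.append(path)
    return paths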
def predict_df():
    ckpt_path = "model"
    import json
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # model.saver.restore(sess, 'model/ner2.ckpt')
        # model.saver.restore(sess, 'model/ner1201_find_lb.ckpt')  # f1:0.6972, precision:0.7403, recall:0.6588, evl_loss:1.2983  model saved, val_loss is: 1.32706
        # model.saver.restore(sess, 'model/ner1208_find_lb.ckpt')  # f1:0.7038, precision:0.7634, recall:0.6528, evl_loss:1.3046  model saved, val_loss is: 1.29316
        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')  # f1:0.70 ner1215
        model.saver.restore(sess, 'model/ner_epoch4_f10.6952_loss1.2512.ckpt')  # f1:0.70 ner1215
        print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)

        # df = pd.read_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
        # df = pd.read_excel('../test/data/所有产品标注数据_补充筛选废标原因数据.xlsx')
        # df = pd.read_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
        df = pd.read_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
        # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
        df.reset_index(drop=True, inplace=True)
        rs = []
        for i in df.index:
            line = df.loc[i, 'text']
            pos = df.loc[i, 'pos']
            reason = df.loc[i, 'reasons_label']
            # Skip negative samples and rows that already carry a labelled reason.
            if pos == 0 or reason != '[]':
                rs.append('')
                continue
            # if i > 200:
            #     rs.append('')
            #     continue
            # line = df.loc[i, 'process_text']
            result = model.evaluate_line(sess, line)
            print(result[0][1])
            rs.append(json.dumps(result[0][1], ensure_ascii=False))
        # df['pred_new1202'] = pd.Series(rs)
        df['reson_model'] = pd.Series(rs)
        # df.to_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
        # df.to_excel('../test/data/所有产品标注数据_补充筛选废标原因数据_predict.xlsx')
        # df.to_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
        df.to_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
        # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
if __name__ == "__main__":
    # train()
    evaluate_line()
    # predict()
    # predict_df()

    # import json
    # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
    # old_new = []
    # new_old = []
    # df['old-new'] = df.apply(lambda x: set([str(it) for it in json.loads(x['pred_old'])]) - set([str(it) for it in json.loads(x['pred_new'])]), axis=1)
    # df['new-old'] = df.apply(lambda x: set([str(it) for it in json.loads(x['pred_new'])]) - set([str(it) for it in json.loads(x['pred_old'])]), axis=1)
    # df['old=new'] = df.apply(lambda x: 1 if x['old-new'] == x['new-old'] else 0, axis=1)
    # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')

    # with open('data/dev_data2.pkl', 'rb') as f:
    #     dev_data = pickle.load(f)
    # import json
    # df_dev = pd.read_excel('data/产品数据自己人标注的原始数据.xlsx')[:]
    # def rows2lb(rows):
    #     rows = json.loads(rows)
    #     rows = list(set([it[0].split()[-1] for it in rows]))
    #     return json.dumps(rows, ensure_ascii=False)
    # df_dev['lbset'] = df_dev['rows'].apply(lambda x: rows2lb(x))
    # dev_data = dfsearchlb(df_dev)
    # dev_manager = BatchManager(dev_data, batch_size=64)
    # # ckpt_path = "model/ner0305.ckpt"  # f1:0.7304, precision:0.8092, recall:0.6656, evl_loss:2.2160
    # # ckpt_path = "model/ner0316.ckpt"  # f1:0.7220, precision:0.7854, recall:0.6681, evl_loss:2.2921
    # # ckpt_path = "model/ner2.ckpt"  # f1:0.8019, precision:0.8541, recall:0.7557, evl_loss:1.6286
    # # ckpt_path = "model/ner1029.ckpt"  # f1:0.6374, precision:0.6897, recall:0.5924, evl_loss:2.0840
    # # ckpt_path = "model/ner1129.ckpt"  # f1:0.6034, precision:0.6931, recall:0.5343, evl_loss:1.9704
    # ckpt_path = "model/ner1129.ckpt"  # f1:0.6034, precision:0.6931, recall:0.5343, evl_loss:1.9704
    # with tf.Session() as sess:
    #     model = Product_Model()
    #     sess.run(tf.global_variables_initializer())
    #     model.saver.restore(sess, ckpt_path)
    #     print("Loaded existing model weights from", ckpt_path)
    #     f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
    #     print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))