#!/usr/bin/python3
# -*- coding: utf-8 -*-
# @Author : bidikeji
# @Time : 2021/1/13 0013 14:03
from BiddingKG.dl.product.product_model import Product_Model
from BiddingKG.dl.product.data_util import BatchManager, get_label_data, id_to_tag, input_from_line, decode, result_to_json, df2data, dfsearchlb
from BiddingKG.dl.product.data_process import data_precess
import numpy as np
import pandas as pd
import tensorflow as tf
import random
import pickle
import os
import glob

os.environ['CUDA_VISIBLE_DEVICES'] = "-1"
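
# Entry points in this script:
#   train()          -- train the Product_Model NER network and save the best checkpoint
#   evaluate_line()  -- restore a checkpoint and tag sentences typed on stdin
#   save_model_pb()  -- freeze a checkpoint into a frozen-graph .pb file
#   predict()        -- load the frozen .pb graph and tag sentences typed on stdin
#   predict_df()     -- restore a checkpoint and tag every row of an Excel file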


def train():
    # all_data = get_label_data()
    # random.shuffle(all_data)
    # train_data = all_data[:int(len(all_data)*0.85)]
    # dev_data = all_data[int(len(all_data)*0.85):]

    # df = pd.read_excel('data/所有产品标注数据筛选20211125.xlsx')
    # df.reset_index(drop=True, inplace=True)
    # np.random.seed(8)
    # shuffle_ids = np.random.permutation(len(df))
    # split_ids = int(len(df)*0.1)
    # train_ids = shuffle_ids[split_ids:]
    # dev_ids = shuffle_ids[:int(split_ids/2)]
    # df_train = df.iloc[train_ids]
    # df_dev = df.iloc[dev_ids]
    # train_data = df2data(df_train)
    # dev_data = df2data(df_dev)

    # with open(os.path.dirname(__file__)+'/data/train_data2021-11-30.pkl', 'rb') as f:
    #     train_data = pickle.load(f)
    # with open(os.path.dirname(__file__)+'/data/dev_data2021-11-30.pkl', 'rb') as f:
    #     dev_data = pickle.load(f)

    train_data, dev_data = data_precess()
    train_manager = BatchManager(train_data, batch_size=256)
    dev_manager = BatchManager(dev_data, batch_size=256)

    # tf_config = tf.ConfigProto()
    # tf_config.gpu_options.allow_growth = True
    tf_config = tf.ConfigProto(device_count={'gpu': 1})
    steps_per_epoch = train_manager.len_data
    ckpt_path = os.path.dirname(__file__) + '/' + "model"
    with tf.Session(config=tf_config) as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        model.saver.restore(sess, os.path.join(ckpt_path, "ner2.ckpt"))
        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print("loaded existing model from", ckpt.model_checkpoint_path)
        print('准备训练数据')  # "preparing training data"
        loss = []
        min_loss = 1000  # lowest validation loss seen so far
        max_f1 = 0       # best validation F1 seen so far
        for i in range(20):
            print('epochs:', i)
            # model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
            # break
            for batch in train_manager.iter_batch(shuffle=True):
                # print('batch:', len(batch))
                # step, batch_loss = model.run_step(sess, True, batch)
                step, batch_loss = model.run_step(sess, 'train', batch)
                loss.append(batch_loss)
                if step % 1000 == 0:
                    iteration = step // steps_per_epoch + 1
                    print('iter:{} step:{} loss:{}'.format(iteration, step, np.mean(loss)))
            if i >= 2 or i % 5 == 0:
                f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
                print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))
                # if max_f1 < f1:
                #     model.saver.save(sess, os.path.join(ckpt_path, "ner2.ckpt"))
                #     print("model saved, best f1 is %.4f" % f1)
                #     max_f1 = f1
                if evl_loss < min_loss and max_f1 < f1:
                    min_loss = evl_loss
                    max_f1 = f1
                    model.saver.save(sess, os.path.join(ckpt_path, "ner1202_find_lb.ckpt"))  # ner1130_find_lb.ckpt
                    print("model saved, val_loss is:", min_loss)
            loss = []
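

# evaluate_line: interactive sanity check -- restore a trained checkpoint and tag
# sentences typed on stdin, one per line.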
def evaluate_line():
    ckpt_path = "model"
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        # model.saver.restore(sess, 'model/ner1215.ckpt')
        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')
        model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')
        while True:
            line = input("请输入测试句子:")  # "please enter a test sentence:"
            result = model.evaluate_line(sess, line)
            print(result)

        # ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # if ckpt and tf.train.checkpoint_exists(ckpt_path):
        #     print('model file:', ckpt.model_checkpoint_path)
        #     model.saver.restore(sess, ckpt.model_checkpoint_path)
        #     print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
        #     while True:
        #         line = input("请输入测试句子:")
        #         result = model.evaluate_line(sess, line)
        #         print(result)
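

# save_model_pb: freeze a trained checkpoint into a single .pb graph (variables folded
# into constants) so the model can be served without the Python model definition.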
def save_model_pb():
    from tensorflow.python.framework import graph_util
    model_folder = r"D:\Bidi\BIDI_ML_INFO_EXTRACTION\BiddingKG\dl\product\model"
    output_graph = r"D:\Bidi\BIDI_ML_INFO_EXTRACTION\BiddingKG\dl\product\model\product.pb"

    # Convert a checkpoint (.ckpt) into a frozen .pb graph.
    # input_checkpoint = "model/ner_epoch5_f10.6855_loss1.3800.ckpt"
    input_checkpoint = "model/ner_epoch22_f10.7923_loss1.1039.ckpt"  # 2023/4/6
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()          # get the default graph
    input_graph_def = graph.as_graph_def()  # its serialized GraphDef
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # restore the graph's variable values
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=["CharInputs", "Sum", "Dropout", "logits/Reshape", "crf_loss/transitions"]
        )
        with tf.gfile.GFile("model/productAndfailreason.pb", "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph" % len(output_graph_def.node))

    # graph = tf.get_default_graph()
    # input_graph_def = graph.as_graph_def()
    # with tf.Session() as sess:
    #     model = Product_Model()
    #     sess.run(tf.global_variables_initializer())
    #     model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')
    #     output_graph_def = graph_util.convert_variables_to_constants(
    #         sess=sess,
    #         input_graph_def=input_graph_def,
    #         output_node_names=["CharInputs", "Sum", "Dropout", "logits/Reshape", "crf_loss/transitions"]
    #     )
    #     with tf.gfile.GFile("model/productAndfailreason.pb", "wb") as f:
    #         f.write(output_graph_def.SerializeToString())

    # with tf.Session() as sess:
    #     model = Product_Model()
    #     sess.run(tf.global_variables_initializer())
    #     model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')
    #     tf.saved_model.simple_save(sess, 'model/productAndfailreason',
    #                                inputs={
    #                                    "CharInputs": model.char_inputs,
    #                                    "Dropout": model.dropout,
    #                                },
    #                                outputs={
    #                                    "Sum": model.lengths,
    #                                    "logits/Reshape": model.logits,
    #                                    "crf_loss/transitions": model.trans
    #                                })

    print('保存pb文件')  # "saved the .pb file"
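

# predict: load the frozen .pb graph produced by save_model_pb() and run interactive
# inference by feeding the exported tensors directly (no Product_Model object needed).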
def predict():
    # pb_path = "model/product.pb"
    pb_path = "model/productAndfailreason.pb"
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name='')  # note: no name prefix may be added here
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for node in output_graph_def.node:
                print(node.name)
            char_input = sess.graph.get_tensor_by_name("CharInputs:0")
            length = sess.graph.get_tensor_by_name("Sum:0")
            dropout = sess.graph.get_tensor_by_name("Dropout:0")
            logit = sess.graph.get_tensor_by_name("logits/Reshape:0")
            tran = sess.graph.get_tensor_by_name("crf_loss/transitions:0")
            while True:
                line = input("请输入测试句子:")  # "please enter a test sentence:"
                _, chars, tags = input_from_line(line)
                print(chars)
                lengths, scores, tran_ = sess.run([length, logit, tran],
                                                  feed_dict={char_input: np.asarray(chars),
                                                             dropout: 1.0})
                batch_paths = decode(scores, lengths, tran_)
                tags = batch_paths[0]  # batch_paths[0][:lengths] would be wrong
                result = result_to_json(line, tags)
                print(result)
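

# predict_df: restore a checkpoint and run batch prediction over the 'text' column of an
# Excel file, writing the JSON-encoded predictions back to a 'product_pred' column.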
def predict_df():
    ckpt_path = "model"
    import json
    with tf.Session() as sess:
        model = Product_Model()
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        # model.saver.restore(sess, 'model/ner2.ckpt')
        # model.saver.restore(sess, 'model/ner1201_find_lb.ckpt')  # f1:0.6972, precision:0.7403, recall:0.6588, evl_loss:1.2983 model saved, val_loss is: 1.32706
        # model.saver.restore(sess, 'model/ner1208_find_lb.ckpt')  # f1:0.7038, precision:0.7634, recall:0.6528, evl_loss:1.3046 model saved, val_loss is: 1.29316
        # model.saver.restore(sess, 'model/ner_f10.7039_loss1.2353.ckpt')  # f1:0.70 ner1215
        model.saver.restore(sess, 'model/ner_epoch5_f10.6855_loss1.3800.ckpt')  # f1:0.70 ner1215
        print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)

        # df = pd.read_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
        # df = pd.read_excel('../test/data/所有产品标注数据_补充筛选废标原因数据.xlsx')
        # df = pd.read_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
        # df = pd.read_excel('data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')  # ../test/
        df = pd.read_excel('data/产品数据自己人标注的原始数据_pred.xlsx')  # ../test/
        df.fillna('', inplace=True)
        # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
        df.reset_index(drop=True, inplace=True)
        rs = []
        for i in df.index:
            line = df.loc[i, 'text']
            # pos = df.loc[i, 'feibiao']
            # reason = df.loc[i, 'reasons_label']
            # if pos == 0 and reason == '[]':
            #     rs.append('')
            #     continue
            # if i > 200:
            #     rs.append('')
            #     continue
            # line = df.loc[i, 'process_text']
            result = model.evaluate_line(sess, line)
            print(result[0][1])
            rs.append(json.dumps(result[0][1], ensure_ascii=False))
        # df['pred_new1202'] = pd.Series(rs)
        # df['reson_model2'] = pd.Series(rs)
        df['product_pred'] = pd.Series(rs)
        # df.to_excel('../test/data/贵州数据新字段提取信息_predict.xlsx')
        # df.to_excel('../test/data/所有产品标注数据_补充筛选废标原因数据_predict.xlsx')
        # df.to_excel('../test/data/所有产品标注数据筛选_废标_predict.xlsx')
        # df.to_excel('../test/data/所有产品标注数据筛选20211125_ProductAndReason.xlsx')
        df.to_excel('data/产品数据自己人标注的原始数据_pred.xlsx')
        # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
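

# Select an entry point by uncommenting the corresponding call below.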
if __name__ == "__main__":
    # train()
    # evaluate_line()
    # save_model_pb()
    predict()
    # predict_df()

    # import json
    # df = pd.read_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')
    # old_new = []
    # new_old = []
    # df['old-new'] = df.apply(lambda x: set([str(it) for it in json.loads(x['pred_old'])]) - set([str(it) for it in json.loads(x['pred_new'])]), axis=1)
    # df['new-old'] = df.apply(lambda x: set([str(it) for it in json.loads(x['pred_new'])]) - set([str(it) for it in json.loads(x['pred_old'])]), axis=1)
    # df['old=new'] = df.apply(lambda x: 1 if x['old-new'] == x['new-old'] else 0, axis=1)
    # df.to_excel('data/所有产品标注数据筛选测试数据2021-12-01_pred.xlsx')

    # with open('data/dev_data2.pkl', 'rb') as f:
    #     dev_data = pickle.load(f)
    # import json
    # df_dev = pd.read_excel('data/产品数据自己人标注的原始数据.xlsx')[:]
    # def rows2lb(rows):
    #     rows = json.loads(rows)
    #     rows = list(set([it[0].split()[-1] for it in rows]))
    #     return json.dumps(rows, ensure_ascii=False)
    # df_dev['lbset'] = df_dev['rows'].apply(lambda x: rows2lb(x))
    # dev_data = dfsearchlb(df_dev)
    # dev_manager = BatchManager(dev_data, batch_size=64)
    # ckpt_path = "model/ner0305.ckpt"  # f1:0.7304, precision:0.8092, recall:0.6656, evl_loss:2.2160
    # ckpt_path = "model/ner0316.ckpt"  # f1:0.7220, precision:0.7854, recall:0.6681, evl_loss:2.2921
    # ckpt_path = "model/ner2.ckpt"     # f1:0.8019, precision:0.8541, recall:0.7557, evl_loss:1.6286
    # ckpt_path = "model/ner1029.ckpt"  # f1:0.6374, precision:0.6897, recall:0.5924, evl_loss:2.0840
    # ckpt_path = "model/ner1129.ckpt"  # f1:0.6034, precision:0.6931, recall:0.5343, evl_loss:1.9704
    # with tf.Session() as sess:
    #     model = Product_Model()
    #     sess.run(tf.global_variables_initializer())
    #     model.saver.restore(sess, ckpt_path)
    #     print("loaded existing model from", ckpt_path)
    #     f1, precision, recall, evl_loss = model.evaluate(sess, data_manager=dev_manager, id_to_tag=id_to_tag)
    #     print('f1:%.4f, precision:%.4f, recall:%.4f, evl_loss:%.4f' % (f1, precision, recall, evl_loss))