predict.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. @author: bidikeji
  5. @time: 2023/3/27 10:19
  6. """
  7. from BiddingKG.dl.product.product_model import Product_Model
  8. import os
  9. import re
  10. import time
  11. import pandas as pd
  12. import tensorflow as tf
  # Expose no CUDA devices to this process, so TensorFlow falls back to the CPU.
  os.environ.update(CUDA_VISIBLE_DEVICES='-1')
  def _abs_path(*parts):
      """Return *parts* joined onto the directory that contains this script.

      On Python < 3.9, ``__file__`` is relative when the script is started
      through a relative path, so ``os.path.dirname(__file__)`` is '' and
      ``dirname(__file__) + '/data/...'`` pointed at the filesystem root.
      Resolving ``__file__`` with ``abspath`` first avoids that.
      """
      return os.path.join(os.path.dirname(os.path.abspath(__file__)), *parts)


  def _predict_text(sess, model, text):
      """Tag *text* sentence by sentence and return the concatenated results.

      The text is split on the Chinese full stop '。'. Each non-empty segment is
      passed to ``model.evaluate_line(sess, segment)``, and element ``[0][1]``
      of each result is appended, in order, to the returned list. That element
      is presumably the list of entities found in the segment; confirm against
      Product_Model.evaluate_line.
      """
      entities = []
      for line in str(text).split('。'):
          if not line:
              # A trailing '。' or a blank cell yields empty segments; there is
              # nothing to tag, so skip the model call.
              continue
          result = model.evaluate_line(sess, line)
          entities.extend(result[0][1])
      return entities


  def predict(ckpt_file=None, data_file=None):
      """Run the product NER model over a spreadsheet of announcements.

      Restores a checkpoint and tags each row's 'text' cell (see
      ``_predict_text``). The JSON-encoded result is stored in a
      'predict_new' column, and the sheet is written back to *data_file*,
      overwriting it.

      Args:
          ckpt_file: Checkpoint path passed to ``model.saver.restore``.
              Defaults to model/ner_epoch22_f10.7923_loss1.1039.ckpt next to
              this script, which was retrained after the data was cleaned up.
          data_file: Excel file to read and overwrite. Defaults to
              data/df_test_pred.xlsx next to this script.

      Returns:
          The DataFrame, including the new 'predict_new' column.
      """
      import json
      if ckpt_file is None:
          # Alternative checkpoints tried before: ner_epoch16_f10.8000_loss1.0775,
          # ner_epoch7_f10.7998_loss1.0508, ner_epoch18_f10.8000_loss1.1276 (new),
          # ner_epoch5_f10.6855_loss1.3800 (old).
          ckpt_file = _abs_path('model', 'ner_epoch22_f10.7923_loss1.1039.ckpt')
      if data_file is None:
          data_file = _abs_path('data', 'df_test_pred.xlsx')
      with tf.Session() as sess:
          model = Product_Model()
          sess.run(tf.global_variables_initializer())
          model.saver.restore(sess, ckpt_file)
          t1 = time.time()
          # Debug print of several model attributes.
          print(model.logits, model.lengths, model.trans, model.dropout, model.char_inputs)
          df = pd.read_excel(data_file)
          print('公告数量:', len(df))  # number of announcements
          df.fillna('', inplace=True)
          df.reset_index(drop=True, inplace=True)
          df['predict_new'] = [
              json.dumps(_predict_text(sess, model, text), ensure_ascii=False)
              for text in df['text']
          ]
          df.to_excel(data_file, index=False)
          print('耗时: ', time.time() - t1)  # elapsed seconds
      return df
  def 统计准确率(df):
      """Compute and print pooled precision, recall and F1 of predictions vs labels.

      The name means "compute accuracy".

      Inputs:
          'predict_new': a JSON list per row. The first element of each item is
              taken as a predicted entity.
          'lbset': a JSON list of labelled entities per row.

      Per row, predictions and labels are compared as sets, and the counts are
      pooled over all rows.

      The following columns are added to *df* in place:
          'pr': the predicted set.
          'lb': the label set.
          'pos': 1 when the two sets are identical, else 0.

      Edge cases:
          Blank cells ('' / NaN / None, e.g. produced by ``fillna('')``) are
          treated as empty lists.
          A ratio whose denominator is zero is reported as 0.0 instead of
          raising ZeroDivisionError.

      Returns:
          A (precision, recall, f1) tuple of floats.
      """
      import json

      def _load(cell):
          # Blank cells carry no entities.
          if cell is None or cell == '' or (isinstance(cell, float) and cell != cell):
              return []
          return json.loads(cell)

      df['pr'] = df['predict_new'].apply(lambda x: {it[0] for it in _load(x)})
      df['lb'] = df['lbset'].apply(lambda x: set(_load(x)))
      df['pos'] = [1 if p == l else 0 for p, l in zip(df['pr'], df['lb'])]

      eq = lb = pr = 0
      for pred, label in zip(df['pr'], df['lb']):
          lb += len(label)
          pr += len(pred)
          eq += len(pred & label)

      acc = eq / pr if pr else 0.0  # precision: correct / predicted
      recall = eq / lb if lb else 0.0  # correct / labelled
      f1 = acc * recall * 2 / (acc + recall) if acc + recall else 0.0
      # The label "准确率" printed here is the precision above.
      # Last recorded run: 准确率:0.6489,召回率:0.8402,F1:0.7323
      print('准确率:%.4f,召回率:%.4f,F1:%.4f' % (acc, recall, f1))
      return acc, recall, f1
  if __name__ == "__main__":
      # Tag the spreadsheet, then score the predictions against its 'lbset' labels.
      统计准确率(predict())