autoRun_replacelabel.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #coding:utf8
  2. from commonutil import *
  3. import psycopg2
  4. import sys
  # Each entry is [source table, des table]: table[0] receives
  # (entity_id, label, rule_id) rows in cir/iterate and is backed up by save;
  # table[1] is read for its entity_id / id / expectation columns.
  entity_circle = [
      [
          "label_guest_wintenderer",
          "is_wintenderer_label_inference",
      ],
  ]
  def save(list):
      """Snapshot each source table into a ``<source>_bak`` backup table.

      Only ``table[0]`` of every entry in ``list`` is used.  If the backup
      table ``<source>_bak`` does not exist it is created with columns
      ``(entity_id text, label int, rule_id text)``; otherwise its previous
      contents are deleted.  The ``entity_id, label, rule_id`` rows of the
      source table are then copied into it.  All tables are handled in one
      transaction, committed after the loop.

      Args:
          list: sequence of ``[source_table, other_table]`` entries such as
              ``entity_circle``.  The name shadows the builtin but is kept so
              existing keyword callers keep working.

      Note:
          Table names are concatenated into the SQL text unescaped, so they
          must come from trusted code, never from user input.
      """
      conn = psycopg2.connect(database="BiddingKM_test_10000", user="postgres",
                              password="postgres", host="192.168.2.101",
                              port="5432")
      try:
          cur = conn.cursor()  # create the cursor object
          for table in list:
              source = table[0]
              result_bak = source + "_bak"
              # to_regclass() yields NULL when no relation of that name exists;
              # the name is passed as a query parameter rather than spliced in.
              cur.execute("SELECT to_regclass(%s) IS NULL", (result_bak,))
              if cur.fetchone()[0]:
                  cur.execute("CREATE TABLE " + result_bak
                              + " (entity_id text, label int, rule_id text)")
              else:
                  # Backup already exists: discard the previous snapshot.
                  cur.execute("DELETE FROM " + result_bak)
              # Same copy statement for both branches.
              cur.execute("INSERT INTO " + result_bak
                          + " (entity_id, label, rule_id)"
                          " SELECT entity_id, label, rule_id FROM " + source)
          conn.commit()
      finally:
          # Release the connection even if a statement fails; psycopg2
          # discards uncommitted changes when the connection is closed.
          conn.close()
  def cir(list, expectation=0.7):
      """Rebuild each source table from thresholded ``expectation`` values.

      For every ``[source, des]`` entry in ``list``, all rows of ``source`` are
      deleted.  Then one row per row of ``des`` is inserted with:

      - label ``1`` when ``des.expectation`` is strictly greater than
        ``expectation``, and ``0`` otherwise;
      - ``rule_id`` set to ``'circle'``.

      All tables are handled in one transaction, committed after the loop.

      Args:
          list: sequence of ``[source_table, des_table]`` entries such as
              ``entity_circle``.
          expectation: threshold compared against ``des.expectation``; it is
              passed to ``cursor.execute`` as a query parameter.

      Note:
          Table names are concatenated into the SQL text unescaped, so they
          must come from trusted code, never from user input.
      """
      conn = psycopg2.connect(database="BiddingKM_test_10000", user="postgres",
                              password="postgres", host="192.168.2.101",
                              port="5432")
      try:
          cur = conn.cursor()  # create the cursor object
          for table in list:
              source, des = table[0], table[1]
              cur.execute("DELETE FROM " + source)
              # The threshold is bound as a parameter instead of str()-spliced
              # into the statement, so it is always quoted/adapted correctly.
              cur.execute("INSERT INTO " + source
                          + " (entity_id, label, rule_id)"
                          " SELECT entity_id,"
                          " CASE WHEN expectation > %s THEN 1 ELSE 0 END,"
                          " 'circle' FROM " + des,
                          (expectation,))
          conn.commit()
      finally:
          # Release the connection even if a statement fails; psycopg2
          # discards uncommitted changes when the connection is closed.
          conn.close()
  def iterate(list, expectation=0.7, upper=0.8, lower=0.2):
      """Relabel holdout entities from ``des`` expectations, falling back to ``_bak``.

      For every ``[source, des]`` entry in ``list``:

      1. Delete the rows of ``source`` whose ``entity_id`` matches a row of
         ``des`` whose ``id`` appears in
         ``dd_graph_variables_holdout.variable_id``.
      2. For each such holdout row ``A`` of ``des``, joined on ``entity_id``
         with ``<source>_bak`` (the backup table written by ``save``) as
         ``B``, insert ``(A.entity_id, label, 'iterate')``.  The ``label`` is:

         - ``1`` if ``A.expectation > upper``;
         - ``-1`` if ``A.expectation < lower``;
         - ``B.label`` otherwise.

      All tables are handled in one transaction, committed after the loop.
      The ``<source>_bak`` table must already exist.

      Args:
          list: sequence of ``[source_table, des_table]`` entries such as
              ``entity_circle``.
          expectation: accepted for interface compatibility but not used;
              the thresholds are ``upper`` and ``lower``.
          upper: expectation strictly above which the label becomes ``1``
              (defaults to the previously hard-coded 0.8).
          lower: expectation strictly below which the label becomes ``-1``
              (defaults to the previously hard-coded 0.2).

      Note:
          Table names are concatenated into the SQL text unescaped, so they
          must come from trusted code, never from user input.
          NOTE(review): ``cir`` writes ``0`` for negative labels while this
          writes ``-1`` -- confirm the two encodings are meant to differ.
      """
      conn = psycopg2.connect(database="BiddingKM_test_10000", user="postgres",
                              password="postgres", host="192.168.2.101",
                              port="5432")
      try:
          cur = conn.cursor()  # create the cursor object
          for table in list:
              source, des = table[0], table[1]
              # Remove the current labels of every holdout entity.
              cur.execute(
                  "DELETE FROM " + source + " S WHERE EXISTS"
                  " (SELECT 1 FROM " + des + " E"
                  " WHERE S.entity_id = E.entity_id"
                  " AND E.id IN (SELECT variable_id"
                  " FROM dd_graph_variables_holdout))")
              # Re-insert them: confident expectations override the label,
              # otherwise the label saved in <source>_bak is kept.
              cur.execute(
                  "INSERT INTO " + source + " (entity_id, label, rule_id)"
                  " SELECT A.entity_id,"
                  " CASE WHEN A.expectation > %s THEN 1"
                  " WHEN A.expectation < %s THEN -1"
                  " ELSE B.label END,"
                  " 'iterate'"
                  " FROM " + des + " A, " + source + "_bak B"
                  " WHERE A.entity_id = B.entity_id"
                  " AND A.id IN (SELECT variable_id"
                  " FROM dd_graph_variables_holdout)",
                  (upper, lower))
          conn.commit()
      finally:
          # Release the connection even if a statement fails; psycopg2
          # discards uncommitted changes when the connection is closed.
          conn.close()
  if __name__ == "__main__":
      # Sub-command name -> step to run over entity_circle.
      commands = {"save": save, "cir": cir, "iterate": iterate}
      args = sys.argv
      # args[0] is the script path, so a sub-command is present only when
      # len(args) > 1; indexing args[1] without that check raises IndexError.
      if len(args) > 1 and args[1] in commands:
          commands[args[1]](entity_circle)
      else:
          # Report a missing or unknown sub-command instead of exiting silently.
          sys.stderr.write("usage: %s save|cir|iterate\n" % args[0])
          sys.exit(2)