train.py

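"""Train a DQN on historical driver-order match data.

Usage (the script expects the pickled training data at
../train_data/train_data.pkl, the path passed to `load` below):

    python train.py
"""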
from myDQN.DQN import DQN
import tensorflow as tf
import numpy as np
import pandas as pd
from entity import *
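
# `load` is expected to come from the `entity` star import above. If it is not
# provided there, a minimal pickle-based stand-in (an assumption, not the
# project's actual API) would look like:
#
#   import pickle
#   def load(path):
#       with open(path, 'rb') as f:
#           return pickle.load(f)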

n_actions = 2    # 0 = order cancelled, 1 = order completed
n_features = 3   # state = (x, y, time), each normalized to [0, 1]
train_step = 10  # run one learning step every `train_step` transitions
max_x = 50       # grid width, used to normalize x coordinates
max_y = 50       # grid height, used to normalize y coordinates
max_time = 144   # time slots per day (assumed 10-minute bins: 24h * 6)
gamma = 0.9      # discount factor
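
# Interface assumed of myDQN.DQN, inferred from the calls in this script:
#   DQN(n_actions, n_features)                            -- constructor
#   store_transition(s, action, reward, travel_time, s_)  -- save one experience
#   learn()                                               -- one training step
#   plot_cost()                                           -- plot the loss curve
#   test()                                                -- evaluate the network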

def train():
    data = load('../train_data/train_data.pkl')
    print("number of match records:", sum(len(i) for i in data))
    step = 0
    learn_num = 0
    RL = DQN(n_actions, n_features)
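    # Each element of `data` is assumed to be one day's list of match records,
    # where a match carries a driver (x, y), an order (order_time, to_x, to_y,
    # arrive_time, travel_time), the fare (money) and a cancellation flag
    # (is_cancel) -- inferred from the attribute accesses below.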
    stop = False
    for d in data:
        for match in d:
            # current state: driver position and order time, normalized to [0, 1]
            s_x = match.driver.x / max_x
            s_y = match.driver.y / max_y
            s_time = match.order.order_time / max_time
            # next state: order destination and arrival time, normalized to [0, 1]
            _s_x = match.order.to_x / max_x
            _s_y = match.order.to_y / max_y
            _s_time = match.order.arrive_time / max_time
            travel_time = match.order.travel_time
            reward = match.money
            # reward shaping from the Didi paper: spread the fare evenly over
            # the trip duration and discount each slice by gamma
            reward = (reward / travel_time) * sum(gamma ** i for i in range(travel_time))
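            # Note: the geometric sum above has the closed form
            # (1 - gamma ** travel_time) / (1 - gamma), so an equivalent,
            # list-free formulation would be:
            #   reward = (match.money / travel_time) * (1 - gamma ** travel_time) / (1 - gamma)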
            # the action actually taken: 0 if the order was cancelled, 1 otherwise
            if match.is_cancel:
                action = 0
            else:
                action = 1
            RL.store_transition([s_x, s_y, s_time], action, reward, travel_time,
                                [_s_x, _s_y, _s_time])
            # start learning after a 200-step warm-up, then every `train_step` transitions
            if (step > 200) and (step % train_step == 0):
                RL.learn()
                learn_num += 1
            step += 1
            if learn_num > 20000:
                stop = True
                break
        if stop:
            # the original break only exited the inner loop; propagate it here
            break
    RL.plot_cost()
    print(RL.test())
    # test_data = np.array([[1/50, 2/50, 10/144],
    #                       [25/50, 46/50, 141/144],
    #                       [45/50, 2/50, 65/144]])
    # print(RL.predict([[16/50, 30/50, 120/144]], [[25/50, 46/50, 141/144]], 200, 21))
    # for test in test_data:
    #     q = RL.predict(test)
    #     print(q)

if __name__ == '__main__':
    train()