# train.py (1013 B)

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from myDQN.DQN import DQN
  2. import tensorflow as tf
  3. import numpy as np
  4. import pandas as pd
  5. from entity import *
  6. n_actions = 2
  7. n_features = 3
  8. max_x = 50
  9. max_y = 50
  10. max_time = 144
  11. def train():
  12. data = load('../train_data/train_data.pkl')
  13. print("数据量:",sum(len(i) for i in data))
  14. step = 0
  15. RL = DQN(n_actions,n_features)
  16. for d in data:
  17. for match in d:
  18. s_x = match.driver.x / max_x
  19. s_y = match.driver.y / max_y
  20. s_time = match.order.order_time / max_time
  21. _s_x = match.order.to_x / max_x
  22. _s_y = match.order.to_y / max_y
  23. _s_time = match.order.arrive_time / max_time
  24. travel_time = match.order.travel_time
  25. reward = match.money
  26. RL.store_transition((s_x,s_y,s_time),0,reward,travel_time,(_s_x,_s_y,_s_time))
  27. if (step > 200) and (step % 10 == 0):
  28. RL.learn()
  29. step += 1
  30. RL.plot_cost()
  31. if __name__ == '__main__':
  32. train()
  33. pass