# train.py
import numpy as np
import pandas as pd

from entity import *
from didi_RL.RL_learning import RL
  5. n_actions = 2
  6. n_features = 3
  7. train_step = 10
  8. max_x = 30
  9. max_y = 30
  10. max_time = 144
  11. gamma = 0.9
  12. # max_x = 10
  13. # max_y = 10
  14. # max_time = 20
  15. # gamma = 0.9
  16. def train():
  17. # data = load('../train_data/train_data03.pkl')
  18. data = load('../train_data111.pkl')
  19. print("数据量:",sum(len(i) for i in data))
  20. step = 0
  21. learn_num = 0
  22. rl = RL(time_step=max_time)
  23. for i in range(len(data)-1,-1,-1):
  24. for match in data[i]:
  25. travel_time = match.travel_time
  26. rl.learn(s = (match.driver.x,match.driver.y),
  27. s_ = (match.order.to_x,match.order.to_y),
  28. t = match.order.order_time,
  29. t_ = match.arrive_time,
  30. r = (match.money / travel_time) * sum([gamma ** i for i in range(travel_time)]),
  31. detal_t = travel_time
  32. )
  33. print("ok")
  34. # rl.save_label('RL_q_label02.pkl')
  35. rl.save_label('RL_q_label03.pkl')
  36. q_label_path = "RL_q_label.pkl"
  37. def predict(s,s_,t,t_,r,detal_t):
  38. q_label = load(q_label_path)
  39. q_now = q_label[t-1][s][1]
  40. q_next = q_label[t_-1][s_][1]
  41. V = gamma ** detal_t * q_next - q_now + r
  42. return V
  43. if __name__ == '__main__':
  44. # train()
  45. a = load('RL_q_label03.pkl')
  46. xx = []
  47. yy = []
  48. Z = []
  49. import matplotlib.pyplot as plt
  50. plt.xlabel('X')
  51. plt.ylabel('Y')
  52. plt.xlim(xmax=30, xmin=0)
  53. plt.ylim(ymax=30, ymin=0)
  54. colors1 = '#00CED1' # 点的颜色
  55. colors2 = '#DC143C'
  56. area = np.pi * 8 # 点面积
  57. # 画散点图
  58. for k,v in a[108].items():
  59. # if 16<=k[0]<=24 and 16<=k[1]<=24:
  60. print(k,v)
  61. xx.append(k[0])
  62. yy.append(k[1])
  63. Z.append(v[1])
  64. colors = colors2 if v[1]>60 else colors1
  65. plt.scatter(k[0], k[1], s=area, c=colors, alpha=0.4)
  66. # plt.plot([0, 9.5], [9.5, 0], linewidth='0.5', color='#000000')
  67. plt.legend()
  68. # 三维图
  69. # from mpl_toolkits.mplot3d import Axes3D
  70. # fig = plt.figure() # 定义新的三维坐标轴
  71. # ax3 = plt.axes(projection='3d')
  72. # ax3.plot_trisurf(xx, yy, Z,cmap='rainbow')
  73. # # 作图
  74. # ax3.plot_surface(X, Y, np.array(Z), cmap='rainbow')
  75. # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap='rainbow) #等高线图,要设置offset,为Z的最小值
  76. plt.show()
  77. pass