train_data.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from entity import *
  2. import random
  3. import numpy as np
  def get_data(num_days=300000, grid_size=50, max_pickup_dist=6,
               out_path="train_data_02.pkl"):
      """Generate random driver/order matches and hand them to ``save``.

      For each of ``num_days`` days, between 10 and 25 matches are created.
      Each match pairs a ``Driver`` placed on a random grid cell with an
      ``Order`` whose pickup point lies within ``max_pickup_dist`` (Manhattan
      distance) of the driver and whose destination is a random grid cell.
      The list of per-day match lists is then passed to ``save`` together
      with ``out_path``.

      The default arguments reproduce the values that were previously
      hard-coded in this function.

      Args:
          num_days: Number of per-day lists to generate. The day index is
              also passed as the first argument of ``Driver``.
          grid_size: Side length of the square grid; all coordinates are
              drawn from ``0 .. grid_size - 1``.
          max_pickup_dist: Largest allowed Manhattan distance between the
              driver position and the order pickup point.
          out_path: Second argument forwarded to ``save``.

      Raises:
          ValueError: If ``grid_size`` is smaller than 1 or
              ``max_pickup_dist`` is negative.
      """
      if grid_size < 1 or max_pickup_dist < 0:
          raise ValueError("grid_size must be >= 1 and max_pickup_dist >= 0")

      grid_max = grid_size - 1
      all_data = []
      for day_id in range(num_days):
          day_data = []
          for i in range(random.randint(10, 25)):
              # Driver position and order destination: anywhere on the grid.
              d_x = random.randint(0, grid_max)
              d_y = random.randint(0, grid_max)
              driver = Driver(day_id, d_x, d_y)
              to_x = random.randint(0, grid_max)
              to_y = random.randint(0, grid_max)

              # Sampling window around the driver, clipped to the grid so the
              # pickup point can never land on a negative or >= grid_size
              # coordinate (the unclipped window allowed that near the edges).
              x_lo = max(0, d_x - max_pickup_dist)
              x_hi = min(grid_max, d_x + max_pickup_dist)
              y_lo = max(0, d_y - max_pickup_dist)
              y_hi = min(grid_max, d_y + max_pickup_dist)

              # Rejection sampling: the window is a square, the accepted
              # region is the Manhattan-distance diamond inside it. The
              # driver's own cell is always acceptable, so this terminates.
              while True:
                  o_x = random.randint(x_lo, x_hi)
                  o_y = random.randint(y_lo, y_hi)
                  if abs(o_x - d_x) + abs(o_y - d_y) <= max_pickup_dist:
                      break

              # NOTE(review): the leading 0 and the trailing value in 1..144
              # are presumably an order id placeholder and a time slot --
              # confirm against entity.Order.
              order = Order(0, o_x, o_y, to_x, to_y, random.randint(1, 144))
              day_data.append(Match(order, driver))

              # Single-line progress indicator: "<day> <match index>".
              print(f'\r{day_id} {i}', end='')
          all_data.append(day_data)

      save(all_data, out_path)
  # Script entry point: build and save the data set only when this file is
  # executed directly, not when it is imported.
  if __name__ == "__main__":
      get_data()
  27. # q = [[6.3724101e-02 ,1.0513991e+02],
  28. # [1.9977689e-03 ,9.5939758e+01],
  29. # [1.5849888e-02, 1.0017528e+02]]
  30. # print(np.max(q, axis=1))