# get_data.py (1.1 KB)
  1. import numpy as np
  2. import pandas as pd
  3. import matplotlib.pyplot as plt
  4. from entity import *
  5. def create_data():
  6. # 区域内当前司机数目
  7. driver_num = np.random.randint(3,100)
  8. # 区域内当前订单数
  9. order_num = np.random.randint(3,100)
  10. # if driver_num > order_num:
  11. # # 价值高的优先
  12. # pass
  13. # elif driver_num <= order_num:
  14. # # 取前order_num个价值高的司机,进行二分图的KM算法求解
  15. # pass
  16. # 司机 与 乘客 的距离
  17. distances = []
  18. for i in range(driver_num):
  19. distance = np.random.normal(3,1,order_num)
  20. # print(distance)
  21. distances.append(distance)
  22. # plt.plot(sorted(distance))
  23. # plt.show()
  24. # 行程距离
  25. travel_distance = np.random.normal(15,6,order_num)
  26. # 订单的价格
  27. rand = np.random.uniform(0.75,0.95,order_num)
  28. price = travel_distance * rand
  29. price = [2 * (i-8) + 10 if i>8 else 10 for i in price]
  30. price = np.array(price) * np.random.uniform(0.88,0.94,order_num)
  31. # print(price)
  32. return list(distances),list(travel_distance),list(price)
  33. if __name__ == '__main__':
  34. create_data()