123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- from myDQN.DQN import DQN
- import tensorflow as tf
- import numpy as np
- import pandas as pd
- from entity import *
# Sizes passed to the DQN constructor in train(): action count, state width.
n_actions, n_features = 2, 3

# Divisors train() applies to x / y coordinates and to time values before
# they are stored as state components.
max_x = max_y = 50
max_time = 144
def _match_to_transition(match):
    """Split one match record into the pieces stored as a replay transition.

    Returns a ``(state, reward, travel_time, next_state)`` tuple:

    * ``state`` is the driver's ``x`` / ``y`` and the order's ``order_time``,
      divided by ``max_x`` / ``max_y`` / ``max_time`` respectively.
    * ``next_state`` is the order's ``to_x`` / ``to_y`` / ``arrive_time``,
      divided by the same three constants.
    * ``reward`` is ``match.money`` and ``travel_time`` is
      ``match.order.travel_time``, both passed through unchanged.
    """
    order = match.order
    state = (
        match.driver.x / max_x,
        match.driver.y / max_y,
        order.order_time / max_time,
    )
    next_state = (
        order.to_x / max_x,
        order.to_y / max_y,
        order.arrive_time / max_time,
    )
    return state, match.money, order.travel_time, next_state


def train(data_path='../train_data/train_data.pkl', *, warmup_steps=200,
          learn_interval=10):
    """Feed every recorded match into a DQN and train it periodically.

    Args:
        data_path: File handed to ``load``. Defaults to the path that was
            previously hard-coded, so ``train()`` behaves as before.
        warmup_steps: ``learn()`` is only called once the step counter is
            strictly greater than this value (previously the literal 200).
        learn_interval: ``learn()`` is called on steps that are a multiple of
            this value (previously the literal 10).

    Raises:
        ValueError: If ``learn_interval`` is not a positive integer.

    NOTE(review): ``load`` and the ``entity`` types are not defined in this
    file; ``load`` is presumably provided by ``from entity import *`` and is
    assumed to return an iterable of iterables of match objects -- confirm.
    NOTE(review): the loop nesting was reconstructed from a copy of the file
    that had lost its indentation; verify against the original layout.
    """
    if learn_interval <= 0:
        raise ValueError("learn_interval must be a positive integer")

    data = load(data_path)
    print("数据量:",sum(len(i) for i in data))

    RL = DQN(n_actions, n_features)

    # Flatten the nested groups so one counter spans all matches, exactly as
    # the former hand-maintained ``step`` variable did (starting at 0).
    all_matches = (match for group in data for match in group)
    for step, match in enumerate(all_matches):
        state, reward, travel_time, next_state = _match_to_transition(match)

        # NOTE(review): the action is always 0 even though n_actions is 2 --
        # confirm this is intended for the recorded data.
        RL.store_transition(state, 0, reward, travel_time, next_state)

        # Learn only after the warm-up, and then only every learn_interval
        # steps.
        if step > warmup_steps and step % learn_interval == 0:
            RL.learn()

    RL.plot_cost()
# Start training only when this file is run as a script, not when imported.
if __name__ == "__main__":
    train()
|