RL_learning.py 887 B

1234567891011121314151617181920212223242526272829303132
  1. import numpy as np
  2. import pandas as pd
  3. from entity import *
  4. class RL():
  5. def __init__(self,
  6. reward_decay = 0.9,
  7. time_step = 144):
  8. self.time_step = time_step
  9. self.gamam = reward_decay
  10. self.q_label = self._build_q_label(time_step)
  11. def _build_q_label(self,time_step):
  12. q_label = [dict() for _ in range(time_step)]
  13. return q_label
  14. def learn(self,s,s_,t,t_,r,detal_t):
  15. if not self.q_label[t-1].get(s):
  16. self.q_label[t-1][s] = [ 0, 0]
  17. n = self.q_label[t-1][s][0] + 1
  18. self.q_label[t-1][s][0] = n
  19. try:
  20. q_next = self.q_label[t_-1][s_][1]
  21. except:
  22. q_next = 0
  23. self.q_label[t-1][s][1] += 1/n * (self.gamam ** detal_t * q_next + r - self.q_label[t-1][s][1])
  24. def save_label(self):
  25. save(self.q_label,"RL_q_label.pkl")