RL_learning.py

import numpy as np
import pandas as pd
from entity import *   # expected to provide save()


class RL():
    """Tabular value learner: one value table per time step, updated by incremental averaging."""

    def __init__(self,
                 reward_decay=0.9,
                 time_step=144):
        self.time_step = time_step
        self.gamma = reward_decay          # discount factor
        self.q_label = self._build_q_label(self.time_step)

    def _build_q_label(self, time_step):
        # One dict per time step, mapping state -> [visit count, value estimate].
        q_label = [dict() for _ in range(time_step)]
        return q_label

    def learn(self, s, s_, t, t_, r, detal_t):
        # Initialise the entry for state s at time step t on first visit.
        if not self.q_label[t - 1].get(s):
            self.q_label[t - 1][s] = [0, 0]
        n = self.q_label[t - 1][s][0] + 1
        self.q_label[t - 1][s][0] = n
        # Value of the successor state s_ at step t_; 0 if it has not been visited yet.
        try:
            q_next = self.q_label[t_ - 1][s_][1]
        except (KeyError, IndexError):
            q_next = 0
        # Running-average update with step size 1/n:
        # V(s, t) <- V(s, t) + 1/n * (r + gamma^detal_t * V(s_, t_) - V(s, t))
        self.q_label[t - 1][s][1] += 1 / n * ((self.gamma ** detal_t) * q_next + r - self.q_label[t - 1][s][1])

    def save_label(self, path=None):
        if path:
            save(self.q_label, path)
        else:
            save(self.q_label, "RL_q_label.pkl")
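A minimal usage sketch, assuming states are hashable keys and that the caller supplies the time-step indices and the discount exponent (detal_t = t_ - t); the states, rewards, and the __main__ guard below are illustrative, not part of the original file:

if __name__ == "__main__":
    agent = RL(reward_decay=0.9, time_step=144)
    # One transition: state "A" at step 1 moves to state "B" at step 3 with reward 1.0.
    agent.learn(s="A", s_="B", t=1, t_=3, r=1.0, detal_t=2)
    # A second visit to the same (state, step) pair refines the running average.
    agent.learn(s="A", s_="B", t=1, t_=3, r=0.5, detal_t=2)
    print(agent.q_label[0]["A"])   # [visit count, value estimate] -> [2, 0.75]
    agent.save_label()             # writes RL_q_label.pkl via entity.save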