1234567891011121314151617181920212223242526272829303132 |
- import numpy as np
- import pandas as pd
- from entity import *
class RL():
    """Tabular per-time-step value table updated with a running average.

    ``q_label`` is a list of ``time_step`` dicts. ``learn`` stores, for a
    state ``s`` seen at time ``t``, the entry ``q_label[t - 1][s]``. That
    entry is a two-item list ``[visit_count, value]``. ``value`` is the
    incremental mean of the discounted target
    ``gamam ** detal_t * value(t_, s_) + r``.
    """

    def __init__(self,
                 reward_decay=0.9,
                 time_step=144):
        """Create an empty value table.

        Args:
            reward_decay: discount factor; ``learn`` raises it to the power
                ``detal_t`` before applying it to the next-state value.
            time_step: number of time slots (dicts) in ``q_label``.
        """
        self.time_step = time_step
        # NOTE: the original (misspelled) attribute name is kept so that
        # existing callers or saved objects that read ``gamam`` keep working.
        self.gamam = reward_decay
        self.q_label = self._build_q_label(time_step)

    def _build_q_label(self, time_step):
        """Return a list of ``time_step`` independent empty dicts."""
        return [dict() for _ in range(time_step)]

    def learn(self, s, s_, t, t_, r, detal_t):
        """Fold one transition into the table entry for ``(t, s)``.

        Args:
            s: state observed at time ``t``; used as a dict key, so it must
                be hashable.
            s_: successor state, observed at time ``t_``.
            t: time index of ``s``; the slot used is ``q_label[t - 1]``
                (presumably 1-based; t = 0 would hit the last slot via
                Python's negative indexing - TODO confirm with callers).
            t_: time index of ``s_``, indexed the same way as ``t``.
            r: reward received for the transition.
            detal_t: elapsed time between ``t`` and ``t_``; used as the
                exponent of the discount factor.

        If the successor entry cannot be looked up, its value is taken as
        0. This covers ``t_ - 1`` out of range, ``s_`` never seen, and a
        ``t_``/``s_`` that cannot be used as an index or key, e.g.
        ``t_=None``. Any other exception propagates.

        Raises:
            IndexError: if ``t - 1`` is outside ``q_label``.
        """
        table = self.q_label[t - 1]
        entry = table.get(s)
        if not entry:
            entry = table[s] = [0, 0]
        entry[0] += 1
        n = entry[0]

        # Value of the successor entry. It is read before this entry is
        # updated, so a self-transition uses the old value.
        try:
            q_next = self.q_label[t_ - 1][s_][1]
        except (IndexError, KeyError, TypeError):
            q_next = 0

        target = self.gamam ** detal_t * q_next + r
        # Incremental mean: step size 1/n over the n visits of (t, s).
        entry[1] += 1.0 / n * (target - entry[1])

    def save_label(self):
        """Write ``q_label`` to the file name ``"RL_q_label.pkl"``.

        ``save`` is not defined in this class; presumably it is provided by
        ``from entity import *`` - verify its serialization format there.
        """
        save(self.q_label, "RL_q_label.pkl")
|