allen пре 3 месеци
родитељ
комит
e9969fc845
1 измењених фајлова са 516 додато и 0 уклоњено
  1. 516 0
      STEPCTQL.py

+ 516 - 0
STEPCTQL.py

@@ -0,0 +1,516 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+A simple example for Reinforcement Learning using table lookup Q-learning method.
+An agent "o" is on the left of a 1 dimensional world, the treasure is on the rightmost location.
+Run this program and to see how the agent will improve its strategy of finding the treasure.
+
+View more on my tutorial page: https://morvanzhou.github.io/tutorials/
+"""
+import numpy as np
+import pandas as pd
+import time
+import smbus
+import math
+import sympy
+from sympy import asin, cos, sin, acos,tan ,atan
+#from DFRobot_RaspberryPi_A02YYUW import DFRobot_A02_Distance as Board
+import paho.mqtt.client as mqtt
+import json
+import RPi.GPIO as GPIO
+from time import sleep
+np.random.seed(2)  # reproducible
+
+#設定步進馬達
+GPIO.setmode(GPIO.BCM)
+DIR = 23
+STEP = 24
+CW = 1
+CCW = 0
+GPIO.setup(DIR, GPIO.OUT)
+GPIO.setup(STEP, GPIO.OUT)
+GPIO.output(DIR, CW)
+
+N_STATES = ['raise_time<1','1<=raise_time<2','2<=raise_time<4','raise_time>4','0<=overshoot<0.33','0.33<overshoot<1','10<=setingtime<20','20<=setingtime<30']  ## 1:time<5      2:5.05<time<5.25    3:5.25<time<5.5  4:time>5.5
+goal=320    #goal
+ACTIONS = ['kp+1','kp+0.1','kp+0.01', 'kp+0','kp-0.01','kp-0.1','kp-1','ki+0.1','ki+0.01', 'ki+0','ki-0.01','ki-0.1','kd+0.1', 'kd+0','kd-0.1']     # available actions
+EPSILON = 0.9   # greedy police
+ALPHA = 0.1     # learning rate
+GAMMA = 0.9    # discount factor
+MAX_EPISODES =1 # maximum episodes
+FRESH_TIME = 0.1    # fresh time for one move
+
+kp=0.0
+ki=0.0
+kd=0.0
+count=50
+x=0
+jsonmsg=""
+S_=""
+
+def step(CW,plus):
+    global DIR,STEP
+    GPIO.output(DIR, CW)
+    for x in range(plus):
+		      # Set one coil winding to high
+			    GPIO.output(STEP,GPIO.HIGH)
+			    # Allow it to get there.
+			    sleep(0.001) # Dictates how fast stepper motor will run
+			    # Set coil winding to low
+			    GPIO.output(STEP,GPIO.LOW)
+			    sleep(0.001)
+    
+def on_connect(client, userdata, flags, rc):
+    print("Connected with result code " + str(rc))
+    client.subscribe("b8:27:eb:eb:21:13/Log", qos=2)
+
+    # 當接收到從伺服器發送的訊息時要進行的動作
+
+
+def on_message(client, userdata, msg):
+    # 轉換編碼utf-8才看得懂中文
+    global  jsonmsg
+    global  x
+    msg.payload = msg.payload.decode('utf-8')
+    jsonmsg = json.loads(msg.payload)
+    x=int(jsonmsg['x'])
+
+      
+class PIDController:
+    def __init__(self, Kp, Ki, Kd):
+        self.Kp = Kp
+        self.Ki = Ki
+        self.Kd = Kd
+        self.last_error = 0
+        self.integral = 0
+
+    def control(self, error):
+        output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
+        self.integral += error
+        self.last_error = error 
+        return output
+    
+   
+class Any_System:
+    def __init__(self, goal):
+        self.target = goal
+        self.current = 0
+
+    def update(self, control_singal):
+        self.current += control_singal
+        return self.current
+
+    def get_error(self):
+        return self.target - self.current
+
+def train(controller,system,num_iterations):
+    global  jsonmsg
+    global x
+    global S_,R
+    errors=[]
+    raise_time=0
+    cont=0
+    rtz=0
+    for _ in range(num_iterations):
+        #error = system.get_error()
+        raise_time+=1
+        current=x
+        error=system.target-current # 真實訊號
+        output=controller.control(error)
+        output=output*4
+        print(raise_time,current,output)
+        if abs(rtz)<7750:
+         if output>0:
+           step(1,int(output))
+           rtz-=int(output)
+         else:
+           step(0,int(-1*output))
+           rtz+=int(-1*output)
+         time.sleep(0.5)
+        #control_signal=controller.control(error)
+        #current=system.update(control_signal)
+        if ((current-system.target)>=0):
+            cont=raise_time
+            break
+        else:
+            cont=50
+    if cont<=10:
+                S_= N_STATES[0]
+                R=5
+                print('raise_time success')
+    elif (10<cont) and (cont<=20):
+                S_= N_STATES[1]
+                R=((15-cont)/10)
+    elif (20< cont) and (cont <= 40):
+                S_= N_STATES[2]
+                R=((30-cont)/10)
+    else:
+                S_= N_STATES[3]
+                R=-(error/100)
+    if rtz>0:
+           step(1,int(rtz))
+    else:
+           step(0,int(-1*rtz))
+    time.sleep(3)
+    print("rtz:",rtz)       
+    return  S_,R
+
+def train2(controller,system,num_iterations):
+    global x
+    global S_,R
+    errors=[]
+    current_arr=[]
+    ot=0
+    over_time=0
+    over_shoot=0
+    rtz=0
+    for _ in range(num_iterations):
+        #error = system.get_error()
+        #current = 23.5 - (board.getDistance() / 10)
+        current = x
+        value=current
+        error = system.target - current  # 真實訊號
+        output=controller.control(error)
+        output=output*4
+        if abs(rtz)<7750:
+         if output>0:
+           step(1,int(output))
+           rtz-=int(output)
+         else:
+           step(0,int(-1*output))
+           rtz+=int(-1*output)
+         time.sleep(0.5)
+        over_time+=1
+        if(value>ot):
+           ot=value
+        print(over_time,ot,output,rtz)
+    over_shoot=float(abs(ot-system.target))/320
+    print("overshoot",str(over_shoot))
+    if over_shoot>=0 and over_shoot < 0.0625:
+        print('overshoot success')
+        S_ = N_STATES[4]
+        R = 5
+    elif (0.0625 <= over_shoot) and (over_shoot <1):
+        S_ = N_STATES[5]
+        R = -1*over_shoot
+    else: 
+        S_ = N_STATES[0]
+        R =  0
+    if rtz>0:
+           step(1,int(rtz))
+    else:
+           step(0,int(-1*rtz))
+    time.sleep(3)
+    print("rtz:",rtz)   
+    return S_, R
+
+def train3(controller,system,num_iterations):
+    global x
+    global S_,R
+    errors=[]
+    cont=0
+    setingtime=0
+    con=0
+    rtz=0
+    print("3")
+    for _ in range(num_iterations):
+        cont=cont+1
+        current = x
+        error = system.target - current  # 真實訊號
+        output=controller.control(error)
+        output=output*4
+        if abs(rtz)<7750: 
+          if output>0:
+           step(1,int(output))
+           rtz-=int(output)
+          else:
+           step(0,int(-1*output))
+           rtz+=int(-1*output)
+          time.sleep(0.5)    
+        print(cont,current,output)
+        if ((-1*error)>=0) and ((-1*error)<=5) and (con==0):
+            setingtime =cont
+            con=1
+        elif(con==0):
+            setingtime=40
+    print(setingtime)
+    if setingtime>=0 and setingtime < 10:
+        S_ = N_STATES[6]
+        R = 10
+        print('setingtime success')
+        with open('pid.txt', 'a') as f:
+            f.write('kp:')
+            f.write(str(controller.Kp))
+            f.write('ki:')
+            f.write(str(controller.Ki))
+            f.write('kd:')
+            f.write(str(controller.Kd))
+            f.write('\r\n')
+    elif (10 <= setingtime) and (setingtime <= 40):
+        S_ = N_STATES[7]
+        R = ((25-setingtime)/10)
+    else:
+        S_ = N_STATES[4]
+        R = -(error/100)
+    if rtz>0:
+           step(1,int(rtz))
+    else:
+           step(0,int(-1*rtz))
+    time.sleep(3)
+    print("rtz:",rtz)   
+    return S_, R
+
+def build_q_table(n_states, actions):
+    try:
+      table = pd.read_csv("/home/pi/pid.csv",index_col=0)
+    except:
+     table = pd.DataFrame(
+        np.zeros((len(n_states), len(actions))),     # q_table initial values
+        columns=actions, index=n_states,   # actions's name
+     )
+    print(table)    # show table
+    return table
+
+
+def choose_action(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+         action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    return action_name
+
+def choose_action1(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    return action_name
+
+def choose_action2(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    print("3")
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = [ 'ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.1', 'ki-0.01']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    print(action_name)    
+    return action_name
+
+def pid(kp):
+    global  goal,count
+    global S_
+    R=0
+    print("raisetime")
+    pid_controller = PIDController(kp,0.0,0.0)
+    any_system = Any_System(goal)
+    S_,R = train(pid_controller, any_system,count)
+    print('kp:',kp)
+    return  S_,R
+
+def pid1(kp):
+    global S_
+    R=0
+    print("overshoot")
+    pid_controller = PIDController(kp, 0.0, 0.0)
+    any_system = Any_System(goal)
+    S_, R = train2(pid_controller, any_system, count)
+    print('kp:', kp)
+    return  S_,R
+
+def pid2(kp,ki,kd):
+    global S_
+    R=0
+    print("setingtime")
+    pid_controller = PIDController(kp,ki,kd)
+    any_system = Any_System(goal)
+    S_, R = train3(pid_controller, any_system, count)
+    print('kp:', kp,'ki',ki,'kd',kd)
+    return  S_,R
+
+def get_env_feedback(S, A):
+    # This is how agent will interact with the environment
+    global kp,S_
+    R=0
+    if A == 'kp+1':    # move right
+          kp+=1
+          S_,R=pid(kp)
+    elif A == 'kp+0.1':    # move right
+          kp+=0.1
+          S_,R=pid(kp)
+    elif A == 'kp+0.01':    # move right
+          kp+=0.01
+          S_,R=pid(kp)
+    elif A=='kp+0':
+          kp=kp+0
+          S_,R= pid(kp)
+    elif A == 'kp-0.01':    # move right
+          kp-=0.01
+          S_,R=pid(kp)
+    elif A == 'kp-0.1':    # move right
+          kp-=0.1
+          S_,R=pid(kp)
+    elif A == 'kp-1':
+          kp-=1
+          S_,R= pid(kp)
+          
+    return S_, R
+
+def get_env_feedback1(S, A):
+    # This is how agent will interact with the environment
+    global kp,S_
+    R=0
+    if A == 'kp+1':    # move right
+          kp+=1
+          S_,R=pid1(kp)
+    elif A == 'kp+0.1':    # move right
+          kp+=0.1
+          S_,R=pid1(kp)
+    elif A == 'kp+0.01':    # move right
+          kp+=0.01
+          S_,R=pid1(kp)
+    elif A=='kp+0':
+          kp=kp+0
+          S_,R= pid1(kp)
+    elif A == 'kp-0.01':    # move right
+          kp-=0.01
+          S_,R=pid1(kp)
+    elif A == 'kp-0.1':    # move right
+          kp-=0.1
+          S_,R=pid1(kp)
+    elif A == 'kp-1':    # move right
+          kp-=1
+          S_,R=pid1(kp)
+    return S_, R
+
+def get_env_feedback2(S, A):
+    # This is how agent will interact with the environment
+    global ki
+    global kp
+    global kd,S_
+    R=0
+    if A == 'ki+0.1':    # move right
+          ki+=0.1
+          S_,R=pid2(kp,ki,kd)
+    elif A == 'ki+0.01':    # move right
+          ki+=0.01
+          S_,R=pid2(kp,ki,kd)
+    elif A=='ki+0':
+          ki=ki+0
+          S_,R= pid2(kp,ki,kd)
+    elif A == 'ki-0.1':    # move right
+          ki-=0.1
+          S_,R=pid2(kp,ki,kd)
+    elif A == 'ki-0.01':    # move right
+          ki-=0.01
+          S_,R=pid2(kp,ki,kd)
+    elif A == 'kd+0.1':    # move right
+          kd+=0.1
+          S_,R=pid2(kp,ki,kd)
+    elif A=='kd+0':
+          kd=kd+0
+          S_,R= pid2(kp,ki,kd)
+    elif A == 'kd-0.1':    # move right
+          kd-=0.1
+          S_,R=pid2(kp,ki,kd)
+    return S_, R
+
+def update_env(S, episode, step_counter):
+    # This is how environment be updated
+    interaction = 'Episode %s: raise_time= %s' % (episode + 1,S)
+    #print('\r{}'.format(interaction), end='')
+    #print('Episode %s: raise_time= %s\r\n' % (episode + 1,S))
+
+
+def rl():
+    global  x,y,z
+    # main part of RL loop
+    q_table = build_q_table(N_STATES, ACTIONS)
+    for episode in range(MAX_EPISODES):
+        S = N_STATES[3]
+        is_terminated = False
+        step(1,7750)
+        time.sleep(5)
+        while not is_terminated:
+           #update_env(S, episode, step_counter)
+           if S==N_STATES[3] or S==N_STATES[2] or S==N_STATES[1] or S==N_STATES[0]:
+              A = choose_action(S, q_table)
+              S_, R = get_env_feedback(S, A)  # take action & get next state and reward
+              q_predict = q_table.loc[S, A]
+              q_target = R + GAMMA * q_table.loc[S_, :].max()   # next state is not terminal
+              q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+              print(q_table)
+              S = S_  # move to next state
+              #update_env(S, episode, step_counter)
+              #step_counter += 1
+              if S==N_STATES[0]:
+                  S=N_STATES[5]
+           elif  S == N_STATES[4] or S == N_STATES[5]:
+               A = choose_action1(S, q_table)
+               S_, R = get_env_feedback1(S, A)  # take action & get next state and reward
+               q_predict = q_table.loc[S, A]
+               print(q_table)
+               q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
+               q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+               S = S_  # move to next state
+               #update_env(S, episode, step_counter)
+               #step_counter += 1
+               if S==N_STATES[4]:
+                  S=N_STATES[7]
+           elif  S == N_STATES[6] or S == N_STATES[7] :
+               A = choose_action2(S, q_table)
+               S_, R = get_env_feedback2(S, A)  # take action & get next state and reward
+               q_predict = q_table.loc[S, A]
+               print(q_table)
+               q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
+               q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+               S = S_  # move to next state
+               if S == N_STATES[6]:
+                 is_terminated = True
+               #update_env(S, episode, step_counter )
+    return q_table
+
+
+def test_pid():
+  global x
+  step(1,7750)
+  time.sleep(5)
+  p_prev_error=0
+  p_integral=0
+  # ----------------------------------------
+  #計算各關節角度
+  while True:
+    current= x
+    p_error=320-current
+    p_output=5.32*p_error+ 0.0*p_integral+0.00*(p_error - p_prev_error)
+    p_integral +=p_error
+    p_prev_error=p_error 
+    if p_output>0:
+           step(1,int(p_output))
+    else:
+           step(0,int(-1*p_output))
+    time.sleep(0.1)
+    
+    
+if __name__ == "__main__":
+    client = mqtt.Client()
+    client.username_pw_set(username='aisky-client', password='aiskyc')
+    client.connect("60.250.156.234", 1883, 60)
+    client.on_connect = on_connect
+    client.on_message = on_message
+    client.loop_start()
+    a="T"
+    if (a=="a"):
+      q_table = rl()
+      print('\r\nQ-table:\n')
+      print(q_table)
+      q_table.to_csv("/home/pi/pid.csv")
+    else:
+      test_pid()