allen 1 месяц назад
Родитель
Сommit
766d4babfd
1 измененных файлов с 583 добавлено и 0 удалено
  1. 583 0
      TQL.py

+ 583 - 0
TQL.py

@@ -0,0 +1,583 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+A simple example for Reinforcement Learning using table lookup Q-learning method.
+An agent "o" is on the left of a 1 dimensional world, the treasure is on the rightmost location.
+Run this program and to see how the agent will improve its strategy of finding the treasure.
+
+View more on my tutorial page: https://morvanzhou.github.io/tutorials/
+"""
+import numpy as np
+import pandas as pd
+import time
+import smbus
+import math
+import sympy
+from sympy import asin, cos, sin, acos,tan ,atan
+from DFRobot_RaspberryPi_A02YYUW import DFRobot_A02_Distance as Board
+
+np.random.seed(2)  # reproducible
+
+
+N_STATES = ['raise_time<1','1<=raise_time<2','2<=raise_time<4','raise_time>4','0<=overshoot<0.33','0.33<overshoot<1','10<=setingtime<20','20<=setingtime<30']  ## 1:time<5      2:5.05<time<5.25    3:5.25<time<5.5  4:time>5.5
+goal=16     #goal
+ACTIONS = ['kp+1','kp+0.1','kp+0.01', 'kp+0','kp-0.01','kp-0.1','kp-1','ki+0.1','ki+0.01', 'ki+0','ki-0.01','ki-0.1','kd+0.01', 'kd+0','kd-0.01']     # available actions
+EPSILON = 0.9   # greedy police
+ALPHA = 0.1     # learning rate
+GAMMA = 0.9    # discount factor
+MAX_EPISODES =1 # maximum episodes
+FRESH_TIME = 0.1    # fresh time for one move
+
+kp=0.0
+ki=0.0
+kd=0.0
+count=50
+y=23.5
+class PCA9685:
+
+  # Registers/etc.
+  __SUBADR1            = 0x02
+  __SUBADR2            = 0x03
+  __SUBADR3            = 0x04
+  __MODE1              = 0x00
+  __PRESCALE           = 0xFE
+  __LED0_ON_L          = 0x06
+  __LED0_ON_H          = 0x07
+  __LED0_OFF_L         = 0x08
+  __LED0_OFF_H         = 0x09
+  __ALLLED_ON_L        = 0xFA
+  __ALLLED_ON_H        = 0xFB
+  __ALLLED_OFF_L       = 0xFC
+  __ALLLED_OFF_H       = 0xFD
+
+  def __init__(self, address=0x60, debug=False):
+      self.bus = smbus.SMBus(1)
+      self.address = address
+      self.debug = debug
+      if (self.debug):
+          print("Reseting PCA9685")
+      self.write(self.__MODE1, 0x00)
+
+  def write(self, reg, value):
+      "Writes an 8-bit value to the specified register/address"
+      self.bus.write_byte_data(self.address, reg, value)
+      if (self.debug):
+          print("I2C: Write 0x%02X to register 0x%02X" % (value, reg))
+
+  def read(self, reg):
+      "Read an unsigned byte from the I2C device"
+      result = self.bus.read_byte_data(self.address, reg)
+      if (self.debug):
+          print("I2C: Device 0x%02X returned 0x%02X from reg 0x%02X" % (self.address, result & 0xFF, reg))
+      return result
+
+  def setPWMFreq(self, freq):
+      "Sets the PWM frequency"
+      prescaleval = 25000000.0  # 25MHz
+      prescaleval /= 4096.0  # 12-bit
+      prescaleval /= float(freq)
+      prescaleval -= 1.0
+      if (self.debug):
+          print("Setting PWM frequency to %d Hz" % freq)
+          print("Estimated pre-scale: %d" % prescaleval)
+      prescale = math.floor(prescaleval + 0.5)
+      if (self.debug):
+          print("Final pre-scale: %d" % prescale)
+
+      oldmode = self.read(self.__MODE1);
+      newmode = (oldmode & 0x7F) | 0x10  # sleep
+      self.write(self.__MODE1, newmode)  # go to sleep
+      self.write(self.__PRESCALE, int(math.floor(prescale)))
+      self.write(self.__MODE1, oldmode)
+      time.sleep(0.005)
+      self.write(self.__MODE1, oldmode | 0x80)
+
+  def setPWM(self, channel, on, off):
+      "Sets a single PWM channel"
+      self.write(self.__LED0_ON_L + 4 * channel, on & 0xFF)
+      self.write(self.__LED0_ON_H + 4 * channel, on >> 8)
+      self.write(self.__LED0_OFF_L + 4 * channel, off & 0xFF)
+      self.write(self.__LED0_OFF_H + 4 * channel, off >> 8)
+      if (self.debug):
+          print("channel: %d  LED_ON: %d LED_OFF: %d" % (channel, on, off))
+
+  def setServoPulse(self, channel, pulse):
+      "Sets the Servo Pulse,The PWM frequency must be 50HZ"
+      pulse = pulse * 4096 / 20000  # PWM frequency is 50HZ,the period is 20000us
+      self.setPWM(channel, 0, int(pulse))
+      
+class PIDController:
+    def __init__(self, Kp, Ki, Kd):
+        self.Kp = Kp
+        self.Ki = Ki
+        self.Kd = Kd
+        self.last_error = 0
+        self.integral = 0
+
+    def control(self, error):
+        output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
+        self.integral += error
+        self.last_error = error
+        return output
+
+
+    def ctr_pwm(self,x,y,z):
+            pwm_1=1750
+            i = 0
+            j = 0
+            # ----------------------------------------
+            r1 = 8
+            r2 = 8
+            # 計算各關節角度
+            x = (x - 320) / 100
+            y = float(y)
+            z = ((240 - z) / 100 )+ 5.3
+            enable = 1
+            print("x=", x)
+            if (x < -19 or x > 19):  # 超出範圍離開
+                print("x over range")
+                enable = 0
+                # break
+            # y= input("input y (0~19):")
+            print("y=", y)
+            if (y < 0 or y > 19):  # 超出範圍離開
+                print("y over range")
+                enable = 0
+                # break
+            # z= input("input z (3~11):")
+            print("z=", z)
+            z2 = z - 3
+            if (z2 < 0 or z2 > 8):  # 超出範圍離開
+                print("z over range")
+                enable = 0
+
+            L = math.sqrt(x * x + y * y) - 8  # 實際夾子前端馬達6.5右邊馬達為1.5公分
+            print("L=", L)
+            if (L < 2):  # 超出範圍離開
+                print("L over range")
+                enable = 0
+
+            h1 = L * L + z2 * z2 + 2 * r1 * L + r1 * r1 - r2 * r2
+            # print("h1=",h1)
+            h2 = -4 * r1 * z2
+            # print("h2=",h2)
+            h3 = L * L + z2 * z2 - 2 * r1 * L + r1 * r1 - r2 * r2
+            # print("h3=",h3)
+            try:
+                Theta = 2 * atan((-h2 + math.sqrt(h2 * h2 - 4 * h1 * h3)) / (2 * h1))  # 無法得出角度離開
+                # print("Theta=",Theta,math.degrees(Theta))
+            except:
+                print("error")
+                enable = 0
+            try:
+                Gamma = asin((z2 - r1 * sin(Theta)) / r2)  # 無法得出角度離開
+                # print("Gamma=",Gamma,math.degrees(Gamma))
+            except:
+                print("error")
+                enable = 0
+            try:
+                tt = x / (r1 * cos(Theta) + r2 * cos(Gamma) + 8)
+                print(tt)
+                if (tt >= 1):
+                    tt = 1
+                if (tt <= -1):
+                    tt = -1
+                Beta = acos(tt)  # 無法得出角度離開
+                # print("Beta=",Beta,math.degrees(Beta))
+            except:
+                print("error")
+                enable = 0
+            if (enable == 1):
+                # Gamma角度轉換為PWM(右邊馬達)
+                pwm_3 = -12.222 * (-math.degrees(Gamma)) + 2300
+                print("pwm_3", pwm_3)
+                error=pwm_3-pwm_1
+                print("error",error)
+                output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
+                self.integral += error
+                self.last_error = error
+                pwm_1+=output*10
+                if pwm_1>1700 and pwm_1<2300:
+                    pwm_1=pwm_1
+                elif pwm_1>2300:
+                     pwm_1=2300             
+                print("output kp ki kd",pwm_1,self.Kp,self.Ki,self.Kd)
+                pwm.setServoPulse(0,2300)  #前面馬達開夾子
+                pwm.setServoPulse(1,1750)  #右邊馬達45度
+                pwm.setServoPulse(14,1600) #底部馬達置中
+                pwm.setServoPulse(15, pwm_1)
+
+class Any_System:
+    def __init__(self, goal):
+        self.target = goal
+        self.current = 0
+
+    def update(self, control_singal):
+        self.current += control_singal
+        return self.current
+
+    def get_error(self):
+        return self.target - self.current
+
+def train(S,controller,system,num_iterations):
+    errors=[]
+    raise_time=0
+    for _ in range(num_iterations):
+        #error = system.get_error()
+        current =23.5-(board.getDistance()/10)
+        with open('dis1.txt', 'a') as f:
+            f.write(str(current))
+            f.write('\r\n')
+        error=system.target-current # 真實訊號
+        controller.ctr_pwm(320,current,240)
+        #control_signal=controller.control(error)
+        #current=system.update(control_signal)
+        errors.append(error)
+        #time.sleep(0.1)
+        raise_time+=1
+        S_ = N_STATES[3]
+        if error >= 0 and error <= system.target:
+            R = 0
+        elif error > system.target:
+            R = -1
+        elif error < 0:
+            R = -1
+        print(raise_time,current)
+        if current>system.target:
+            print('time',raise_time)
+            if raise_time<10:
+                S_= N_STATES[0]
+                R=5
+            elif (10<=raise_time) and (raise_time<20):
+                S_= N_STATES[1]
+                R=3
+            elif (20<= raise_time) and (raise_time < 40):
+                S_= N_STATES[2]
+                R=2
+            else:
+                S_= N_STATES[3]
+                if  error>0 and error < system.target:
+                    R = 0
+                elif error >system.target:
+                    R = -1
+                elif error <0:
+                    R=-1
+            return S_, R
+    return  S_,R
+
+def train2(S,controller,system,num_iterations):
+    errors=[]
+    current_arr=[]
+    overshoot_value=[]
+    for _ in range(num_iterations):
+        #error = system.get_error()
+        current = 23.5 - (board.getDistance() / 10)
+        with open('dis2.txt', 'a') as f:
+            f.write(str(current))
+            f.write('\r\n')
+        error = system.target - current  # 真實訊號
+        controller.ctr_pwm(320,current,240)
+        #control_signal=controller.control(error)
+        #current=system.update(control_signal)
+        errors.append(error)
+        #time.sleep(0.1)
+        current_arr.append(current)
+    for i in range(num_iterations):
+        if (current_arr[i]-system.target>=0):
+            overshoot_value.append((current_arr[i] - system.target) / system.target)
+        print(i,current_arr[i])
+    #min(temp_arr[9:19])
+    #print(min(temp_arr[9:19]))
+    #overshoot=abs((min(temp_arr[9:19])-30)/30)
+    try:
+        overshoot=max(overshoot_value)
+    except:
+        overshoot =1
+    print(overshoot)
+    if overshoot>=0 and overshoot < 0.0625:
+        print('overshoot success')
+        S_ = N_STATES[4]
+        R = 2
+    elif (0.0625 <= overshoot) and (overshoot < 1):
+        S_ = N_STATES[5]
+        R = 1
+    else:
+        S_ = N_STATES[0]
+        R = 0
+    return S_, R
+
+def train3(S,controller,system,num_iterations):
+    errors=[]
+    current_arr=[]
+    for _ in range(num_iterations):
+        #error = system.get_error()
+        current = 23.5 - (board.getDistance() / 10)
+        with open('dis3.txt', 'a') as f:
+            f.write(str(current))
+            f.write('\r\n')
+        error = system.target - current  # 真實訊號
+        controller.ctr_pwm(320,current,240)
+        #control_signal=controller.control(error)
+        #current=system.update(control_signal)
+        errors.append(error)
+        #time.sleep(0.1)
+        current_arr.append(current)
+    if (abs(current_arr[10]-system.target)) < 5:
+        setingtime =10
+    elif (abs(current_arr[20]-system.target)) < 5:
+        setingtime =20
+    elif (abs(current_arr[30]-system.target)) < 5:
+        setingtime =30
+    else:
+        setingtime=31
+    for i in range(9,49):
+        if (abs(current_arr[i] - system.target))>5:
+            setingtime=31
+
+    print(setingtime)
+    if setingtime>=10 and setingtime < 20:
+        S_ = N_STATES[6]
+        R = 2
+        print('setingtime success')
+        with open('pid.txt', 'a') as f:
+            f.write('kp:')
+            f.write(str(controller.Kp))
+            f.write('ki:')
+            f.write(str(controller.Ki))
+            f.write('kd:')
+            f.write(str(controller.Kd))
+            f.write('\r\n')
+    elif (20 <= setingtime) and (setingtime < 30):
+        S_ = N_STATES[7]
+        R = 1
+    else:
+        S_ = N_STATES[4]
+        R = 0
+    return S_, R
+
+def build_q_table(n_states, actions):
+    try:
+      table = pd.read_csv("/home/pi/pid.csv",index_col=0)
+    except:
+     table = pd.DataFrame(
+        np.zeros((len(n_states), len(actions))),     # q_table initial values
+        columns=actions, index=n_states,   # actions's name
+     )
+    print(table)    # show table
+    return table
+
+
+def choose_action(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    return action_name
+
+def choose_action1(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    return action_name
+
+def choose_action2(state, q_table):
+    # This is how to choose an action
+    state_actions = q_table.loc[state, :]
+    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy or state-action have no value
+        ACT = [ 'ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.01', 'ki-0.1','kd+0.01', 'kd+0','kd-0.01']
+        action_name = np.random.choice(ACT)
+    else:   # act greedy
+        action_name = state_actions.idxmax()    # replace argmax to idxmax as argmax means a different function in newer version of pandas
+    return action_name
+
+def pid(S,kp):
+    global  goal,count
+    print('kp:',kp)
+    pid_controller = PIDController(kp,0.0,0.0)
+    any_system = Any_System(goal)
+    S_,R = train(S,pid_controller, any_system,count)
+    return  S_,R
+
+def pid1(S,kp):
+    print("overshoot")
+    pid_controller = PIDController(kp, 0.0, 0.0)
+    any_system = Any_System(goal)
+    S_, R = train2(S, pid_controller, any_system, count)
+    print('kp:', kp)
+    return  S_,R
+
+def pid2(S,kp,ki,kd):
+    print("setingtime")
+    pid_controller = PIDController(kp,ki,kd)
+    any_system = Any_System(goal)
+    S_, R = train3(S, pid_controller, any_system, count)
+    print('kp:', kp,'ki',ki,'kd',kd)
+    return  S_,R
+
+def get_env_feedback(S, A):
+    # This is how agent will interact with the environment
+    global kp
+    if A == 'kp+1':    # move right
+          kp+=1
+          S_,R=pid(S,kp)
+    elif A == 'kp+0.1':    # move right
+          kp+=0.1
+          S_,R=pid(S,kp)
+    elif A == 'kp+0.01':    # move right
+          kp+=0.01
+          S_,R=pid(S,kp)
+    elif A=='kp+0':
+          kp=kp+0
+          S_,R= pid(S,kp)
+    elif A == 'kp-0.01':    # move right
+          kp-=0.01
+          S_,R=pid(S,kp)
+    elif A == 'kp-0.1':    # move right
+          kp-=0.1
+          S_,R=pid(S,kp)
+    elif A == 'kp-1':
+          kp-=1
+          S_,R= pid(S,kp)
+    return S_, R
+
+def get_env_feedback1(S, A):
+    # This is how agent will interact with the environment
+    global kp
+    if A == 'kp+1':    # move right
+          kp+=1
+          S_,R=pid1(S,kp)
+    elif A == 'kp+0.1':    # move right
+          kp+=0.1
+          S_,R=pid1(S,kp)
+    elif A == 'kp+0.01':    # move right
+          kp+=0.01
+          S_,R=pid1(S,kp)
+    elif A=='kp+0':
+          kp=kp+0
+          S_,R= pid1(S,kp)
+    elif A == 'kp-0.01':    # move right
+          kp-=0.01
+          S_,R=pid1(S,kp)
+    elif A == 'kp-0.1':    # move right
+          kp-=0.1
+          S_,R=pid1(S,kp)
+    elif A == 'kp-1':
+          kp-=1
+          S_,R= pid1(S,kp)
+    return S_, R
+
+def get_env_feedback2(S, A):
+    # This is how agent will interact with the environment
+    global ki
+    global kp
+    global kd
+    if A == 'ki+0.1':    # move right
+          ki+=0.1
+          S_,R=pid2(S,kp,ki,kd)
+    elif A == 'ki+0.01':    # move right
+          ki+=0.01
+          S_,R=pid2(S,kp,ki,kd)
+    elif A=='ki+0':
+          ki=ki+0
+          S_,R= pid2(S,kp,ki,kd)
+    elif A == 'ki-0.01':    # move right
+          ki-=0.01
+          S_,R=pid2(S,kp,ki,kd)
+    elif A == 'ki-0.1':    # move right
+          ki-=0.1
+          S_,R=pid2(S,kp,ki,kd)
+    elif A == 'kd+0.01':    # move right
+          kd+=0.01
+          S_,R=pid2(S,kp,ki,kd)
+    elif A=='kd+0':
+          kd=kd+0
+          S_,R= pid2(S,kp,ki,kd)
+    elif A == 'kd-0.01':    # move right
+          kd-=0.01
+          S_,R=pid2(S,kp,ki,kd)
+    return S_, R
+
+def update_env(S, episode, step_counter):
+    # This is how environment be updated
+    interaction = 'Episode %s: raise_time= %s' % (episode + 1,S)
+    #print('\r{}'.format(interaction), end='')
+    #print('Episode %s: raise_time= %s\r\n' % (episode + 1,S))
+
+
+def rl():
+    # main part of RL loop
+    global x, y, z, w
+    q_table = build_q_table(N_STATES, ACTIONS)
+    for episode in range(MAX_EPISODES):
+        S = N_STATES[3]
+        is_terminated = False
+        while not is_terminated:
+           x = 320
+           z = 240
+           y = 15.5
+           pwm.setServoPulse(0, 1100)  # 前面馬達開夾子
+           pwm.setServoPulse(1, 1750)  # 右邊馬達45度
+           pwm.setServoPulse(14, 1600)  # 底部馬達置中
+           pwm.setServoPulse(15, 1700)  # 左邊馬達垂直
+           time.sleep(1)
+           pwm.setServoPulse(0, 0)
+           pwm.setServoPulse(1, 0)
+           pwm.setServoPulse(14, 0)
+           pwm.setServoPulse(15, 0)
+           #update_env(S, episode, step_counter)
+           if S==N_STATES[3] or S==N_STATES[2] or S==N_STATES[1] or S==N_STATES[0]:
+              A = choose_action(S, q_table)
+              S_, R = get_env_feedback(S, A)  # take action & get next state and reward
+              q_predict = q_table.loc[S, A]
+              q_target = R + GAMMA * q_table.loc[S_, :].max()   # next state is not terminal
+              q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+              print(q_table)
+              S = S_  # move to next state
+              #update_env(S, episode, step_counter)
+              #step_counter += 1
+              if S==N_STATES[0]:
+                  S=N_STATES[5]
+           elif  S == N_STATES[4] or S == N_STATES[5]:
+               print("raise_time success")
+               A = choose_action1(S, q_table)
+               S_, R = get_env_feedback1(S, A)  # take action & get next state and reward
+               q_predict = q_table.loc[S, A]
+               q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
+               q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+               S = S_  # move to next state
+               #update_env(S, episode, step_counter)
+               #step_counter += 1
+               if S==N_STATES[4]:
+                  S=N_STATES[7]
+           elif  S == N_STATES[6] or S == N_STATES[7] :
+               A = choose_action2(S, q_table)
+               S_, R = get_env_feedback2(S, A)  # take action & get next state and reward
+               q_predict = q_table.loc[S, A]
+               q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
+               q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
+               S = S_  # move to next state
+               if S == N_STATES[6]:
+                 is_terminated = True
+               #update_env(S, episode, step_counter )
+    return q_table
+
+
+
+if __name__ == "__main__":
+    pwm = PCA9685(0x60, debug=False)
+    pwm.setPWMFreq(50)
+    board = Board()
+    dis_min = 0   #Minimum ranging threshold: 0mm
+    dis_max = 4500 #Highest ranging threshold: 4500mm
+    board.set_dis_range(dis_min, dis_max)
+    q_table = rl()
+    print('\r\nQ-table:\n')
+    print(q_table)
+    q_table.to_csv("/home/pi/pid.csv")