# CTQL.py — Q-learning-based PID auto-tuner for a PCA9685-driven servo.
  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. """
  4. A simple example for Reinforcement Learning using table lookup Q-learning method.
  5. An agent "o" is on the left of a 1 dimensional world, the treasure is on the rightmost location.
  6. Run this program and to see how the agent will improve its strategy of finding the treasure.
  7. View more on my tutorial page: https://morvanzhou.github.io/tutorials/
  8. """
  9. import numpy as np
  10. import pandas as pd
  11. import time
  12. import smbus
  13. import math
  14. import sympy
  15. from sympy import asin, cos, sin, acos,tan ,atan
  16. #from DFRobot_RaspberryPi_A02YYUW import DFRobot_A02_Distance as Board
  17. import paho.mqtt.client as mqtt
  18. import json
np.random.seed(2) # reproducible
# Discrete Q-learning states: rise-time buckets (0-3), overshoot buckets (4-5)
# and settling-time buckets (6-7) used by the three tuning phases below.
N_STATES = ['raise_time<1','1<=raise_time<2','2<=raise_time<4','raise_time>4','0<=overshoot<0.33','0.33<overshoot<1','10<=setingtime<20','20<=setingtime<30'] ## 1:time<5 2:5.05<time<5.25 3:5.25<time<5.5 4:time>5.5
goal=320 # target position the controller drives the measurement toward
ACTIONS = ['kp+1','kp+0.1','kp+0.01', 'kp+0','kp-0.01','kp-0.1','kp-1','ki+0.1','ki+0.01', 'ki+0','ki-0.01','ki-0.1','kd+0.1', 'kd+0','kd-0.1'] # available actions: gain increments/decrements
EPSILON = 0.9 # greedy policy threshold (explore when uniform() > EPSILON)
ALPHA = 0.1 # learning rate
GAMMA = 0.9 # discount factor
MAX_EPISODES =1 # maximum episodes
FRESH_TIME = 0.1 # fresh time for one move
# PID gains adjusted by the Q-learning agent
kp=0.0
ki=0.0
kd=0.0
count=50 # control steps per evaluation phase (train/train2/train3)
x=0 # latest position received over MQTT (updated in on_message)
jsonmsg="" # last decoded MQTT JSON payload
pwm_1=1600 # neutral servo pulse width (microseconds)
S_="" # next state, shared between the feedback helpers
# Accumulated state for the hand-tuned loop in coordinate_pwm_pid()
p_prev_error=0
p_integral=0
  38. def on_connect(client, userdata, flags, rc):
  39. print("Connected with result code " + str(rc))
  40. client.subscribe("b8:27:eb:eb:21:13/Log", qos=2)
# Action to take when a message arrives from the MQTT broker.
  42. def on_message(client, userdata, msg):
  43. # 轉換編碼utf-8才看得懂中文
  44. global jsonmsg
  45. global x
  46. msg.payload = msg.payload.decode('utf-8')
  47. jsonmsg = json.loads(msg.payload)
  48. x=int(jsonmsg['x'])
class PCA9685:
    """Driver for the PCA9685 16-channel 12-bit I2C PWM controller.

    Used here at 50 Hz to generate servo pulses.  Register addresses below
    follow the PCA9685 datasheet.
    """
    # Registers/etc.
    __SUBADR1 = 0x02
    __SUBADR2 = 0x03
    __SUBADR3 = 0x04
    __MODE1 = 0x00
    __PRESCALE = 0xFE
    __LED0_ON_L = 0x06
    __LED0_ON_H = 0x07
    __LED0_OFF_L = 0x08
    __LED0_OFF_H = 0x09
    __ALLLED_ON_L = 0xFA
    __ALLLED_ON_H = 0xFB
    __ALLLED_OFF_L = 0xFC
    __ALLLED_OFF_H = 0xFD

    def __init__(self, address=0x60, debug=False):
        """Open I2C bus 1 and reset the chip (MODE1 = 0x00, normal mode)."""
        self.bus = smbus.SMBus(1)
        self.address = address
        self.debug = debug
        if (self.debug):
            print("Reseting PCA9685")
        self.write(self.__MODE1, 0x00)

    def write(self, reg, value):
        "Writes an 8-bit value to the specified register/address"
        self.bus.write_byte_data(self.address, reg, value)
        if (self.debug):
            print("I2C: Write 0x%02X to register 0x%02X" % (value, reg))

    def read(self, reg):
        "Read an unsigned byte from the I2C device"
        result = self.bus.read_byte_data(self.address, reg)
        if (self.debug):
            print("I2C: Device 0x%02X returned 0x%02X from reg 0x%02X" % (self.address, result & 0xFF, reg))
        return result

    def setPWMFreq(self, freq):
        "Sets the PWM frequency"
        # Prescale = round(25MHz / (4096 * freq)) - 1, per the datasheet.
        prescaleval = 25000000.0 # 25MHz
        prescaleval /= 4096.0 # 12-bit
        prescaleval /= float(freq)
        prescaleval -= 1.0
        if (self.debug):
            print("Setting PWM frequency to %d Hz" % freq)
            print("Estimated pre-scale: %d" % prescaleval)
        prescale = math.floor(prescaleval + 0.5)
        if (self.debug):
            print("Final pre-scale: %d" % prescale)
        # PRESCALE can only be written while the oscillator is in sleep mode.
        oldmode = self.read(self.__MODE1);
        newmode = (oldmode & 0x7F) | 0x10 # sleep
        self.write(self.__MODE1, newmode) # go to sleep
        self.write(self.__PRESCALE, int(math.floor(prescale)))
        self.write(self.__MODE1, oldmode)
        time.sleep(0.005)
        # Set the RESTART bit so PWM channels resume with the new prescale.
        self.write(self.__MODE1, oldmode | 0x80)

    def setPWM(self, channel, on, off):
        "Sets a single PWM channel"
        # Each channel has 4 registers: ON low/high, OFF low/high (12-bit counts).
        self.write(self.__LED0_ON_L + 4 * channel, on & 0xFF)
        self.write(self.__LED0_ON_H + 4 * channel, on >> 8)
        self.write(self.__LED0_OFF_L + 4 * channel, off & 0xFF)
        self.write(self.__LED0_OFF_H + 4 * channel, off >> 8)
        if (self.debug):
            print("channel: %d LED_ON: %d LED_OFF: %d" % (channel, on, off))

    def setServoPulse(self, channel, pulse):
        "Sets the Servo Pulse,The PWM frequency must be 50HZ"
        # Convert microseconds to 12-bit ticks of the 20000us (50 Hz) period.
        pulse = pulse * 4096 / 20000 # PWM frequency is 50HZ,the period is 20000us
        self.setPWM(channel, 0, int(pulse))
  113. class PIDController:
  114. def __init__(self, Kp, Ki, Kd):
  115. self.Kp = Kp
  116. self.Ki = Ki
  117. self.Kd = Kd
  118. self.last_error = 0
  119. self.integral = 0
  120. def control(self, error):
  121. output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
  122. self.integral += error
  123. self.last_error = error
  124. return output
  125. class Any_System:
  126. def __init__(self, goal):
  127. self.target = goal
  128. self.current = 0
  129. def update(self, control_singal):
  130. self.current += control_singal
  131. return self.current
  132. def get_error(self):
  133. return self.target - self.current
  134. def train(controller,system,num_iterations):
  135. global jsonmsg
  136. global x
  137. global S_,R
  138. errors=[]
  139. raise_time=0
  140. cont=0
  141. for _ in range(num_iterations):
  142. #error = system.get_error()
  143. raise_time+=1
  144. current=x
  145. print(raise_time,current)
  146. error=system.target-current # 真實訊號
  147. output=controller.control(error)
  148. current+= output
  149. pwm_3=0.6156*current+1396
  150. if(pwm_3<1500):
  151. pwm_3=1500
  152. elif(pwm_3>1750):
  153. pwm_3=1750
  154. pwm.setServoPulse(14,pwm_3) #底部馬達置中
  155. time.sleep(0.5)
  156. #control_signal=controller.control(error)
  157. #current=system.update(control_signal)
  158. if ((current-system.target)>0):
  159. cont=raise_time
  160. break
  161. else:
  162. cont=50
  163. if cont<=10:
  164. S_= N_STATES[0]
  165. R=5
  166. print('raise_time success')
  167. elif (10<cont) and (cont<=20):
  168. S_= N_STATES[1]
  169. R=2
  170. elif (20< cont) and (cont <= 40):
  171. S_= N_STATES[2]
  172. R=1
  173. else:
  174. S_= N_STATES[3]
  175. R=1.5-(error/100)
  176. return S_,R
  177. def train2(controller,system,num_iterations):
  178. global x
  179. global S_,R
  180. errors=[]
  181. current_arr=[]
  182. ot=0
  183. over_time=0
  184. over_shoot=0
  185. for _ in range(num_iterations):
  186. #error = system.get_error()
  187. #current = 23.5 - (board.getDistance() / 10)
  188. current = x
  189. value=current
  190. error = system.target - current # 真實訊號
  191. output=controller.control(error)
  192. current+= output
  193. pwm_3=0.6156*current+1396
  194. if(pwm_3<1500):
  195. pwm_3=1500
  196. elif(pwm_3>1750):
  197. pwm_3=1750
  198. pwm.setServoPulse(14,pwm_3) #底部馬達置中
  199. time.sleep(0.5)
  200. over_time+=1
  201. if(value>ot):
  202. ot=value
  203. print(over_time,ot)
  204. over_shoot=float(abs(ot-system.target))/320
  205. print("overshoot",str(over_shoot))
  206. if over_shoot>=0 and over_shoot < 0.0625:
  207. print('overshoot success')
  208. S_ = N_STATES[4]
  209. R = 5
  210. elif (0.0625 <= over_shoot) and (over_shoot <1):
  211. S_ = N_STATES[5]
  212. R = -1*over_shoot
  213. else:
  214. S_ = N_STATES[0]
  215. R = 0
  216. return S_, R
  217. def train3(controller,system,num_iterations):
  218. global x
  219. global S_,R
  220. errors=[]
  221. cont=0
  222. setingtime=0
  223. con=0
  224. print("3")
  225. for _ in range(num_iterations):
  226. cont=cont+1
  227. current = x
  228. error = system.target - current # 真實訊號
  229. output=controller.control(error)
  230. current+= output
  231. pwm_3=0.6156*current+1396
  232. if(pwm_3<1500):
  233. pwm_3=1500
  234. elif(pwm_3>1750):
  235. pwm_3=1750
  236. pwm.setServoPulse(14,pwm_3) #底部馬達置中
  237. time.sleep(0.5)
  238. print(cont,error)
  239. if ((-1*error)>=0) and ((-1*error)<=5) and (con==0):
  240. setingtime =cont
  241. con=1
  242. elif(con==0):
  243. setingtime=40
  244. print(setingtime)
  245. if setingtime>=0 and setingtime < 10:
  246. S_ = N_STATES[6]
  247. R = 10
  248. print('setingtime success')
  249. with open('pid.txt', 'a') as f:
  250. f.write('kp:')
  251. f.write(str(controller.Kp))
  252. f.write('ki:')
  253. f.write(str(controller.Ki))
  254. f.write('kd:')
  255. f.write(str(controller.Kd))
  256. f.write('\r\n')
  257. elif (10 <= setingtime) and (setingtime <= 40):
  258. S_ = N_STATES[7]
  259. R = 1.5 -1*(error/100)
  260. else:
  261. S_ = N_STATES[4]
  262. R = -1
  263. return S_, R
  264. def build_q_table(n_states, actions):
  265. try:
  266. table = pd.read_csv("/home/pi/pid.csv",index_col=0)
  267. except:
  268. table = pd.DataFrame(
  269. np.zeros((len(n_states), len(actions))), # q_table initial values
  270. columns=actions, index=n_states, # actions's name
  271. )
  272. print(table) # show table
  273. return table
  274. def choose_action(state, q_table):
  275. # This is how to choose an action
  276. state_actions = q_table.loc[state, :]
  277. if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()): # act non-greedy or state-action have no value
  278. ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
  279. action_name = np.random.choice(ACT)
  280. else: # act greedy
  281. action_name = state_actions.idxmax() # replace argmax to idxmax as argmax means a different function in newer version of pandas
  282. return action_name
  283. def choose_action1(state, q_table):
  284. # This is how to choose an action
  285. state_actions = q_table.loc[state, :]
  286. if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()): # act non-greedy or state-action have no value
  287. ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
  288. action_name = np.random.choice(ACT)
  289. else: # act greedy
  290. action_name = state_actions.idxmax() # replace argmax to idxmax as argmax means a different function in newer version of pandas
  291. return action_name
  292. def choose_action2(state, q_table):
  293. # This is how to choose an action
  294. state_actions = q_table.loc[state, :]
  295. print("3")
  296. if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()): # act non-greedy or state-action have no value
  297. ACT = [ 'ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.1', 'ki-0.01']
  298. action_name = np.random.choice(ACT)
  299. else: # act greedy
  300. action_name = state_actions.idxmax() # replace argmax to idxmax as argmax means a different function in newer version of pandas
  301. print(action_name)
  302. return action_name
  303. def pid(kp):
  304. global goal,count
  305. global S_
  306. R=0
  307. print("raisetime")
  308. pid_controller = PIDController(kp,0.0,0.0)
  309. any_system = Any_System(goal)
  310. S_,R = train(pid_controller, any_system,count)
  311. print('kp:',kp)
  312. return S_,R
  313. def pid1(kp):
  314. global S_
  315. R=0
  316. print("overshoot")
  317. pid_controller = PIDController(kp, 0.0, 0.0)
  318. any_system = Any_System(goal)
  319. S_, R = train2(pid_controller, any_system, count)
  320. print('kp:', kp)
  321. return S_,R
  322. def pid2(kp,ki,kd):
  323. global S_
  324. R=0
  325. print("setingtime")
  326. pid_controller = PIDController(kp,ki,kd)
  327. any_system = Any_System(goal)
  328. S_, R = train3(pid_controller, any_system, count)
  329. print('kp:', kp,'ki',ki,'kd',kd)
  330. return S_,R
  331. def get_env_feedback(S, A):
  332. # This is how agent will interact with the environment
  333. global kp,S_
  334. R=0
  335. if A == 'kp+1': # move right
  336. kp+=1
  337. S_,R=pid(kp)
  338. elif A == 'kp+0.1': # move right
  339. kp+=0.1
  340. S_,R=pid(kp)
  341. elif A == 'kp+0.01': # move right
  342. kp+=0.01
  343. S_,R=pid(kp)
  344. elif A=='kp+0':
  345. kp=kp+0
  346. S_,R= pid(kp)
  347. elif A == 'kp-0.01': # move right
  348. kp-=0.01
  349. S_,R=pid(kp)
  350. elif A == 'kp-0.1': # move right
  351. kp-=0.1
  352. S_,R=pid(kp)
  353. elif A == 'kp-1':
  354. kp-=1
  355. S_,R= pid(kp)
  356. return S_, R
  357. def get_env_feedback1(S, A):
  358. # This is how agent will interact with the environment
  359. global kp,S_
  360. R=0
  361. if A == 'kp+1': # move right
  362. kp+=1
  363. S_,R=pid1(kp)
  364. elif A == 'kp+0.1': # move right
  365. kp+=0.1
  366. S_,R=pid1(kp)
  367. elif A == 'kp+0.01': # move right
  368. kp+=0.01
  369. S_,R=pid1(kp)
  370. elif A=='kp+0':
  371. kp=kp+0
  372. S_,R= pid1(kp)
  373. elif A == 'kp-0.01': # move right
  374. kp-=0.01
  375. S_,R=pid1(kp)
  376. elif A == 'kp-0.1': # move right
  377. kp-=0.1
  378. S_,R=pid1(kp)
  379. elif A == 'kp-1': # move right
  380. kp-=1
  381. S_,R=pid1(kp)
  382. return S_, R
  383. def get_env_feedback2(S, A):
  384. # This is how agent will interact with the environment
  385. global ki
  386. global kp
  387. global kd,S_
  388. R=0
  389. if A == 'ki+0.1': # move right
  390. ki+=0.1
  391. S_,R=pid2(kp,ki,kd)
  392. elif A == 'ki+0.01': # move right
  393. ki+=0.01
  394. S_,R=pid2(kp,ki,kd)
  395. elif A=='ki+0':
  396. ki=ki+0
  397. S_,R= pid2(kp,ki,kd)
  398. elif A == 'ki-0.1': # move right
  399. ki-=0.1
  400. S_,R=pid2(kp,ki,kd)
  401. elif A == 'ki-0.01': # move right
  402. ki-=0.01
  403. S_,R=pid2(kp,ki,kd)
  404. elif A == 'kd+0.1': # move right
  405. kd+=0.1
  406. S_,R=pid2(kp,ki,kd)
  407. elif A=='kd+0':
  408. kd=kd+0
  409. S_,R= pid2(kp,ki,kd)
  410. elif A == 'kd-0.1': # move right
  411. kd-=0.1
  412. S_,R=pid2(kp,ki,kd)
  413. return S_, R
def update_env(S, episode, step_counter):
    """Render the environment status (currently disabled).

    All printing is commented out, so this builds the (unused) status
    string and is effectively a no-op kept for debugging.
    """
    # This is how environment be updated
    interaction = 'Episode %s: raise_time= %s' % (episode + 1,S)
    #print('\r{}'.format(interaction), end='')
    #print('Episode %s: raise_time= %s\r\n' % (episode + 1,S))
def rl():
    """Main Q-learning loop: tune kp (rise time), kp (overshoot), then ki/kd.

    Walks a three-phase state machine — rise-time states 0-3, overshoot
    states 4-5, settling-time states 6-7 — performing a standard tabular
    Q-update after every action.  Terminates an episode once the fast
    settling state (N_STATES[6]) is reached.  Returns the learned Q-table.
    """
    global x,y,z
    # main part of RL loop
    q_table = build_q_table(N_STATES, ACTIONS)
    for episode in range(MAX_EPISODES):
        S = N_STATES[3]  # start in the worst rise-time state
        is_terminated = False
        pwm.setServoPulse(14, 1600)  # center the base motor
        time.sleep(5)
        while not is_terminated:
            #update_env(S, episode, step_counter)
            if S==N_STATES[3] or S==N_STATES[2] or S==N_STATES[1] or S==N_STATES[0]:
                # Phase 1: rise time — adjust kp only.
                pwm.setServoPulse(14, 1600)  # center the base motor
                time.sleep(3)
                A = choose_action(S, q_table)
                S_, R = get_env_feedback(S, A) # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                q_target = R + GAMMA * q_table.loc[S_, :].max() # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict) # update
                print(q_table)
                S = S_ # move to next state
                #update_env(S, episode, step_counter)
                #step_counter += 1
                if S==N_STATES[0]:
                    # Rise time solved: jump into the overshoot phase.
                    S=N_STATES[5]
            elif S == N_STATES[4] or S == N_STATES[5]:
                # Phase 2: overshoot — still adjusting kp.
                pwm.setServoPulse(14, 1600)  # center the base motor
                time.sleep(3)
                A = choose_action1(S, q_table)
                S_, R = get_env_feedback1(S, A) # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                print(q_table)
                q_target = R + GAMMA * q_table.loc[S_, :].max() # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict) # update
                S = S_ # move to next state
                #update_env(S, episode, step_counter)
                #step_counter += 1
                if S==N_STATES[4]:
                    # Overshoot solved: jump into the settling-time phase.
                    S=N_STATES[7]
            elif S == N_STATES[6] or S == N_STATES[7] :
                # Phase 3: settling time — adjust ki/kd.
                pwm.setServoPulse(14, 1600)  # center the base motor
                print("ok")
                time.sleep(3)
                A = choose_action2(S, q_table)
                S_, R = get_env_feedback2(S, A) # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                print(q_table)
                q_target = R + GAMMA * q_table.loc[S_, :].max() # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict) # update
                S = S_ # move to next state
                if S == N_STATES[6]:
                    # Fast settling achieved — episode is done.
                    is_terminated = True
                #update_env(S, episode, step_counter )
    return q_table
def coordinate_pwm_pid():
    """Run a fixed-gain PID loop forever, steering the servo toward x = 320.

    Reads the measured position from the MQTT-updated global ``x`` every
    0.1 s and drives servo channel 14 with the clamped pulse.
    NOTE(review): gains are hard-coded (Kp≈2.06, Ki=0.1, Kd=0) — presumably
    the output of a previous Q-learning run; confirm.
    """
    global p_prev_error,p_integral,x
    # ----------------------------------------
    # Compute the servo command for the joint each cycle.
    while True:
        current= x
        p_error=320-current
        p_output=2.0599999999999987*p_error+ 0.1*p_integral+0.00*(p_error - p_prev_error)
        p_integral +=p_error
        p_prev_error=p_error
        current+= p_output
        # Map the commanded position to a servo pulse (same mapping as train*).
        pwm_1=0.6156*current+1396
        print(pwm_1)
        # Clamp to the mechanically safe pulse range.
        if(pwm_1<1500):
            pwm_1=1500
        elif(pwm_1>1750):
            pwm_1=1750
        pwm.setServoPulse(14,pwm_1)
        time.sleep(0.1)
if __name__ == "__main__":
    # Hardware setup: PCA9685 PWM driver at I2C address 0x60, 50 Hz for servos.
    pwm = PCA9685(0x60, debug=False)
    pwm.setPWMFreq(50)
    #board = Board()
    #dis_min = 0 #Minimum ranging threshold: 0mm
    #dis_max = 4500 #Highest ranging threshold: 4500mm
    #board.set_dis_range(dis_min, dis_max)
    # MQTT setup: position updates arrive on the Log topic (see on_message).
    # NOTE(review): broker IP and credentials are hard-coded — move to config.
    client = mqtt.Client()
    client.username_pw_set(username='aisky-client', password='aiskyc')
    client.connect("60.250.156.234", 1883, 60)
    # NOTE(review): callbacks assigned after connect(); they still fire because
    # the network loop only starts below, but convention is to assign first.
    client.on_connect = on_connect
    client.on_message = on_message
    client.loop_start()
    #a = raw_input("input:")
    a="t"  # "a" = run the Q-learning tuner; anything else = fixed-gain loop
    if (a=="a"):
        q_table = rl()
        print('\r\nQ-table:\n')
        print(q_table)
        q_table.to_csv("/home/pi/pid.csv")
    else:
        coordinate_pwm_pid()