TQL.py

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
A simple example of Reinforcement Learning using the table-lookup Q-learning method,
adapted here to tune the Kp/Ki/Kd gains of a PID-controlled robot arm.
Run this program to see how the agent improves its tuning strategy.
View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""
import numpy as np
import pandas as pd
import time
import smbus
import math
from sympy import asin, cos, sin, acos, tan, atan
from DFRobot_RaspberryPi_A02YYUW import DFRobot_A02_Distance as Board

np.random.seed(2)  # reproducible

# Discrete states: bands of rise time, overshoot, and settling time
N_STATES = ['raise_time<1', '1<=raise_time<2', '2<=raise_time<4', 'raise_time>4',
            '0<=overshoot<0.33', '0.33<overshoot<1',
            '10<=setingtime<20', '20<=setingtime<30']
goal = 16             # control target (cm)
# Available actions: increments applied to the PID gains
ACTIONS = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1',
           'ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.01', 'ki-0.1',
           'kd+0.01', 'kd+0', 'kd-0.01']
EPSILON = 0.9         # greedy policy (epsilon-greedy exploration)
ALPHA = 0.1           # learning rate
GAMMA = 0.9           # discount factor
MAX_EPISODES = 1      # maximum episodes
FRESH_TIME = 0.1      # refresh time for one move
kp = 0.0
ki = 0.0
kd = 0.0
count = 50            # control-loop samples per training run
y = 23.5              # sensor reference distance (cm)
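# The Q-learning problem is set up as follows: the eight N_STATES entries band the
# step response into rise-time, overshoot, and settling-time ranges, the fifteen
# ACTIONS are increments applied to the PID gains, and each (state, action) cell of
# the Q-table accumulates the banded rewards returned by train/train2/train3 below.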
class PCA9685:
    """Minimal driver for the PCA9685 16-channel PWM controller."""
    # Registers/etc.
    __SUBADR1      = 0x02
    __SUBADR2      = 0x03
    __SUBADR3      = 0x04
    __MODE1        = 0x00
    __PRESCALE     = 0xFE
    __LED0_ON_L    = 0x06
    __LED0_ON_H    = 0x07
    __LED0_OFF_L   = 0x08
    __LED0_OFF_H   = 0x09
    __ALLLED_ON_L  = 0xFA
    __ALLLED_ON_H  = 0xFB
    __ALLLED_OFF_L = 0xFC
    __ALLLED_OFF_H = 0xFD

    def __init__(self, address=0x60, debug=False):
        self.bus = smbus.SMBus(1)
        self.address = address
        self.debug = debug
        if self.debug:
            print("Resetting PCA9685")
        self.write(self.__MODE1, 0x00)

    def write(self, reg, value):
        "Writes an 8-bit value to the specified register/address"
        self.bus.write_byte_data(self.address, reg, value)
        if self.debug:
            print("I2C: Write 0x%02X to register 0x%02X" % (value, reg))

    def read(self, reg):
        "Read an unsigned byte from the I2C device"
        result = self.bus.read_byte_data(self.address, reg)
        if self.debug:
            print("I2C: Device 0x%02X returned 0x%02X from reg 0x%02X" % (self.address, result & 0xFF, reg))
        return result

    def setPWMFreq(self, freq):
        "Sets the PWM frequency"
        prescaleval = 25000000.0    # 25 MHz internal oscillator
        prescaleval /= 4096.0       # 12-bit counter
        prescaleval /= float(freq)
        prescaleval -= 1.0
        if self.debug:
            print("Setting PWM frequency to %d Hz" % freq)
            print("Estimated pre-scale: %d" % prescaleval)
        prescale = math.floor(prescaleval + 0.5)
        if self.debug:
            print("Final pre-scale: %d" % prescale)
        oldmode = self.read(self.__MODE1)
        newmode = (oldmode & 0x7F) | 0x10    # sleep
        self.write(self.__MODE1, newmode)    # go to sleep
        self.write(self.__PRESCALE, int(math.floor(prescale)))
        self.write(self.__MODE1, oldmode)
        time.sleep(0.005)
        self.write(self.__MODE1, oldmode | 0x80)

    def setPWM(self, channel, on, off):
        "Sets a single PWM channel"
        self.write(self.__LED0_ON_L + 4 * channel, on & 0xFF)
        self.write(self.__LED0_ON_H + 4 * channel, on >> 8)
        self.write(self.__LED0_OFF_L + 4 * channel, off & 0xFF)
        self.write(self.__LED0_OFF_H + 4 * channel, off >> 8)
        if self.debug:
            print("channel: %d LED_ON: %d LED_OFF: %d" % (channel, on, off))

    def setServoPulse(self, channel, pulse):
        "Sets the servo pulse; the PWM frequency must be 50 Hz"
        pulse = pulse * 4096 / 20000    # at 50 Hz the period is 20000 us
        self.setPWM(channel, 0, int(pulse))
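# Example of the conversion in setServoPulse: at 50 Hz each PWM period is 20000 us,
# split into 4096 counter ticks, so a 1500 us pulse (typical servo center) maps to
# 1500 * 4096 / 20000 = 307 ticks.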
class PIDController:
    def __init__(self, Kp, Ki, Kd):
        self.Kp = Kp
        self.Ki = Ki
        self.Kd = Kd
        self.last_error = 0
        self.integral = 0

    def control(self, error):
        output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
        self.integral += error
        self.last_error = error
        return output
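    # Discrete positional PID law implemented by control() (the sample period is
    # folded into the gains):
    #     u[k] = Kp * e[k] + Ki * sum_{j<k} e[j] + Kd * (e[k] - e[k-1])
    # Note the integral term uses the error sum accumulated before the current call.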
    def ctr_pwm(self, x, y, z):
        pwm_1 = 1750
        i = 0
        j = 0
        # ----------------------------------------
        r1 = 8    # link lengths (cm)
        r2 = 8
        # Compute each joint angle
        x = (x - 320) / 100
        y = float(y)
        z = ((240 - z) / 100) + 5.3
        enable = 1
        print("x=", x)
        if x < -19 or x > 19:    # out of range, abort
            print("x over range")
            enable = 0
        print("y=", y)
        if y < 0 or y > 19:    # out of range, abort
            print("y over range")
            enable = 0
        print("z=", z)
        z2 = z - 3
        if z2 < 0 or z2 > 8:    # out of range, abort
            print("z over range")
            enable = 0
        L = math.sqrt(x * x + y * y) - 8    # gripper front motor is 6.5 cm and right motor 1.5 cm from the tip
        print("L=", L)
        if L < 2:    # out of range, abort
            print("L over range")
            enable = 0
        h1 = L * L + z2 * z2 + 2 * r1 * L + r1 * r1 - r2 * r2
        h2 = -4 * r1 * z2
        h3 = L * L + z2 * z2 - 2 * r1 * L + r1 * r1 - r2 * r2
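        # The two-link arm geometry reduces to a quadratic in t = tan(Theta / 2),
        #     h1 * t^2 + h2 * t + h3 = 0
        # (a standard tan-half-angle substitution), so Theta below is recovered from
        # one root of that quadratic via Theta = 2 * atan(t).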
        try:
            Theta = 2 * atan((-h2 + math.sqrt(h2 * h2 - 4 * h1 * h3)) / (2 * h1))    # abort if no angle can be derived
        except Exception:
            print("error")
            enable = 0
        try:
            Gamma = asin((z2 - r1 * sin(Theta)) / r2)    # abort if no angle can be derived
        except Exception:
            print("error")
            enable = 0
        try:
            tt = x / (r1 * cos(Theta) + r2 * cos(Gamma) + 8)
            print(tt)
            if tt >= 1:
                tt = 1
            if tt <= -1:
                tt = -1    # clamp into the acos domain
            Beta = acos(tt)    # abort if no angle can be derived
        except Exception:
            print("error")
            enable = 0
        if enable == 1:
            # Convert the Gamma angle to a PWM pulse (right motor)
            pwm_3 = -12.222 * (-math.degrees(Gamma)) + 2300
            print("pwm_3", pwm_3)
            error = pwm_3 - pwm_1
            print("error", error)
            output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
            self.integral += error
            self.last_error = error
            pwm_1 += output * 10
            # Clamp the pulse to the servo's usable range
            if pwm_1 > 2300:
                pwm_1 = 2300
            elif pwm_1 < 1700:
                pwm_1 = 1700
            print("output kp ki kd", pwm_1, self.Kp, self.Ki, self.Kd)
            pwm.setServoPulse(0, 2300)     # front motor: open the gripper
            pwm.setServoPulse(1, 1750)     # right motor at 45 degrees
            pwm.setServoPulse(14, 1600)    # base motor centered
            pwm.setServoPulse(15, pwm_1)
class Any_System:
    """Minimal plant stub holding the control target and current value."""
    def __init__(self, goal):
        self.target = goal
        self.current = 0

    def update(self, control_signal):
        self.current += control_signal
        return self.current

    def get_error(self):
        return self.target - self.current
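# Note: the training loops below bypass Any_System.update()/get_error() and read the
# real plant instead: the current height comes from the A02YYUW distance sensor via
# `board`, and the control action is applied through PIDController.ctr_pwm().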
def train(S, controller, system, num_iterations):
    """Run the control loop and score the rise time (samples until the target is crossed)."""
    errors = []
    raise_time = 0
    for _ in range(num_iterations):
        current = 23.5 - (board.getDistance() / 10)    # real measurement (cm)
        with open('dis1.txt', 'a') as f:
            f.write(str(current))
            f.write('\r\n')
        error = system.target - current    # real signal
        controller.ctr_pwm(320, current, 240)
        errors.append(error)
        raise_time += 1
        S_ = N_STATES[3]
        if 0 <= error <= system.target:
            R = 0
        else:
            R = -1
        print(raise_time, current)
        if current > system.target:    # target crossed: band the rise time
            print('time', raise_time)
            if raise_time < 10:
                S_ = N_STATES[0]
                R = 5
            elif 10 <= raise_time < 20:
                S_ = N_STATES[1]
                R = 3
            elif 20 <= raise_time < 40:
                S_ = N_STATES[2]
                R = 2
            else:
                S_ = N_STATES[3]
                if 0 < error < system.target:
                    R = 0
                else:
                    R = -1
            return S_, R
    return S_, R
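# Reward shaping for train(): crossing the 16 cm target within 10 samples scores R=5
# (state 'raise_time<1'), within 20 samples R=3, within 40 samples R=2; never crossing
# keeps the worst rise-time state with R in {0, -1} depending on the last error.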
def train2(S, controller, system, num_iterations):
    """Run the control loop and score the overshoot relative to the target."""
    errors = []
    current_arr = []
    overshoot_value = []
    for _ in range(num_iterations):
        current = 23.5 - (board.getDistance() / 10)
        with open('dis2.txt', 'a') as f:
            f.write(str(current))
            f.write('\r\n')
        error = system.target - current    # real signal
        controller.ctr_pwm(320, current, 240)
        errors.append(error)
        current_arr.append(current)
    for i in range(num_iterations):
        if current_arr[i] - system.target >= 0:
            overshoot_value.append((current_arr[i] - system.target) / system.target)
            print(i, current_arr[i])
    try:
        overshoot = max(overshoot_value)
    except ValueError:    # target never exceeded
        overshoot = 1
    print(overshoot)
    if 0 <= overshoot < 0.0625:
        print('overshoot success')
        S_ = N_STATES[4]
        R = 2
    elif 0.0625 <= overshoot < 1:
        S_ = N_STATES[5]
        R = 1
    else:
        S_ = N_STATES[0]
        R = 0
    return S_, R
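# With goal=16 the 0.0625 threshold corresponds to an absolute overshoot of
# 0.0625 * 16 = 1 cm, so 'overshoot success' means the height never exceeds 17 cm.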
def train3(S, controller, system, num_iterations):
    """Run the control loop and score the settling time (samples until |error| stays < 5 cm)."""
    errors = []
    current_arr = []
    for _ in range(num_iterations):
        current = 23.5 - (board.getDistance() / 10)
        with open('dis3.txt', 'a') as f:
            f.write(str(current))
            f.write('\r\n')
        error = system.target - current    # real signal
        controller.ctr_pwm(320, current, 240)
        errors.append(error)
        current_arr.append(current)
    if abs(current_arr[10] - system.target) < 5:
        settling_time = 10
    elif abs(current_arr[20] - system.target) < 5:
        settling_time = 20
    elif abs(current_arr[30] - system.target) < 5:
        settling_time = 30
    else:
        settling_time = 31
    for i in range(9, 49):    # any later excursion beyond 5 cm resets the score
        if abs(current_arr[i] - system.target) > 5:
            settling_time = 31
    print(settling_time)
    if 10 <= settling_time < 20:
        S_ = N_STATES[6]
        R = 2
        print('settling time success')
        with open('pid.txt', 'a') as f:
            f.write('kp:')
            f.write(str(controller.Kp))
            f.write('ki:')
            f.write(str(controller.Ki))
            f.write('kd:')
            f.write(str(controller.Kd))
            f.write('\r\n')
    elif 20 <= settling_time < 30:
        S_ = N_STATES[7]
        R = 1
    else:
        S_ = N_STATES[4]
        R = 0
    return S_, R
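# Once the best settling-time band is reached with R=2 (state '10<=setingtime<20'),
# the gains that achieved it are appended to pid.txt; rl() treats that state as terminal.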
def build_q_table(n_states, actions):
    try:
        table = pd.read_csv("/home/pi/pid.csv", index_col=0)    # resume from a saved table
    except Exception:
        table = pd.DataFrame(
            np.zeros((len(n_states), len(actions))),    # q_table initial values
            columns=actions, index=n_states,            # actions' names
        )
    print(table)    # show table
    return table
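# The resulting table is 8 states x 15 actions, indexed by the N_STATES labels with
# one column per ACTIONS entry; all entries start at zero unless /home/pi/pid.csv
# already holds values from a previous run.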
def choose_action(state, q_table):
    # Epsilon-greedy action selection over the kp actions
    state_actions = q_table.loc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):    # explore, or the row is still all zeros
        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
        action_name = np.random.choice(ACT)
    else:    # act greedily
        action_name = state_actions.idxmax()    # idxmax replaces argmax, which means something different in newer pandas
    return action_name
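# With EPSILON = 0.9 the agent explores on roughly 10% of picks (np.random.uniform()
# exceeds 0.9), and always explores while the state's row is untouched, so early
# episodes sample the gain adjustments uniformly at random.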
def choose_action1(state, q_table):
    # Same epsilon-greedy rule as choose_action, restricted to the kp actions
    state_actions = q_table.loc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):
        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
        action_name = np.random.choice(ACT)
    else:    # act greedily
        action_name = state_actions.idxmax()
    return action_name

def choose_action2(state, q_table):
    # Epsilon-greedy over the ki/kd actions (used once kp has been tuned)
    state_actions = q_table.loc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):
        ACT = ['ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.01', 'ki-0.1', 'kd+0.01', 'kd+0', 'kd-0.01']
        action_name = np.random.choice(ACT)
    else:    # act greedily
        action_name = state_actions.idxmax()
    return action_name
def pid(S, kp):
    # Rise-time stage: P-only controller
    global goal, count
    print('kp:', kp)
    pid_controller = PIDController(kp, 0.0, 0.0)
    any_system = Any_System(goal)
    S_, R = train(S, pid_controller, any_system, count)
    return S_, R

def pid1(S, kp):
    # Overshoot stage: P-only controller
    print("overshoot")
    pid_controller = PIDController(kp, 0.0, 0.0)
    any_system = Any_System(goal)
    S_, R = train2(S, pid_controller, any_system, count)
    print('kp:', kp)
    return S_, R

def pid2(S, kp, ki, kd):
    # Settling-time stage: full PID controller
    print("settling time")
    pid_controller = PIDController(kp, ki, kd)
    any_system = Any_System(goal)
    S_, R = train3(S, pid_controller, any_system, count)
    print('kp:', kp, 'ki', ki, 'kd', kd)
    return S_, R
def get_env_feedback(S, A):
    # Apply the chosen kp adjustment, then evaluate the rise time
    global kp
    if A == 'kp+1':
        kp += 1
    elif A == 'kp+0.1':
        kp += 0.1
    elif A == 'kp+0.01':
        kp += 0.01
    elif A == 'kp+0':
        pass    # keep kp unchanged
    elif A == 'kp-0.01':
        kp -= 0.01
    elif A == 'kp-0.1':
        kp -= 0.1
    elif A == 'kp-1':
        kp -= 1
    S_, R = pid(S, kp)
    return S_, R

def get_env_feedback1(S, A):
    # Apply the chosen kp adjustment, then evaluate the overshoot
    global kp
    if A == 'kp+1':
        kp += 1
    elif A == 'kp+0.1':
        kp += 0.1
    elif A == 'kp+0.01':
        kp += 0.01
    elif A == 'kp+0':
        pass    # keep kp unchanged
    elif A == 'kp-0.01':
        kp -= 0.01
    elif A == 'kp-0.1':
        kp -= 0.1
    elif A == 'kp-1':
        kp -= 1
    S_, R = pid1(S, kp)
    return S_, R

def get_env_feedback2(S, A):
    # Apply the chosen ki/kd adjustment, then evaluate the settling time
    global kp, ki, kd
    if A == 'ki+0.1':
        ki += 0.1
    elif A == 'ki+0.01':
        ki += 0.01
    elif A == 'ki+0':
        pass    # keep ki unchanged
    elif A == 'ki-0.01':
        ki -= 0.01
    elif A == 'ki-0.1':
        ki -= 0.1
    elif A == 'kd+0.01':
        kd += 0.01
    elif A == 'kd+0':
        pass    # keep kd unchanged
    elif A == 'kd-0.01':
        kd -= 0.01
    S_, R = pid2(S, kp, ki, kd)
    return S_, R
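# A more compact alternative (hedged sketch, not part of the original design): each
# action string encodes "<gain><signed delta>", so it could be parsed generically:
#
#     def apply_action(A):
#         gains = {'kp': kp, 'ki': ki, 'kd': kd}
#         name, delta = A[:2], float(A[2:])    # e.g. 'kp+0.1' -> ('kp', 0.1)
#         gains[name] += delta
#
# The explicit if/elif chains above are kept because they mirror the ACTIONS list.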
def update_env(S, episode, step_counter):
    # This is how the environment would be reported
    interaction = 'Episode %s: raise_time= %s' % (episode + 1, S)
    #print('\r{}'.format(interaction), end='')

def rl():
    # Main part of the RL loop
    global x, y, z, w
    q_table = build_q_table(N_STATES, ACTIONS)
    for episode in range(MAX_EPISODES):
        S = N_STATES[3]
        is_terminated = False
        while not is_terminated:
            x = 320
            z = 240
            y = 15.5
            # Move the arm to its starting pose, then release the servos
            pwm.setServoPulse(0, 1100)     # front motor: open the gripper
            pwm.setServoPulse(1, 1750)     # right motor at 45 degrees
            pwm.setServoPulse(14, 1600)    # base motor centered
            pwm.setServoPulse(15, 1700)    # left motor vertical
            time.sleep(1)
            pwm.setServoPulse(0, 0)
            pwm.setServoPulse(1, 0)
            pwm.setServoPulse(14, 0)
            pwm.setServoPulse(15, 0)
            if S in (N_STATES[0], N_STATES[1], N_STATES[2], N_STATES[3]):
                # Stage 1: tune kp against rise time
                A = choose_action(S, q_table)
                S_, R = get_env_feedback(S, A)    # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                q_target = R + GAMMA * q_table.loc[S_, :].max()    # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)    # update
                print(q_table)
                S = S_    # move to next state
                if S == N_STATES[0]:
                    S = N_STATES[5]    # best rise-time band: advance to the overshoot stage
            elif S == N_STATES[4] or S == N_STATES[5]:
                # Stage 2: tune kp against overshoot
                print("raise_time success")
                A = choose_action1(S, q_table)
                S_, R = get_env_feedback1(S, A)
                q_predict = q_table.loc[S, A]
                q_target = R + GAMMA * q_table.loc[S_, :].max()
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)
                S = S_
                if S == N_STATES[4]:
                    S = N_STATES[7]    # best overshoot band: advance to the settling-time stage
            elif S == N_STATES[6] or S == N_STATES[7]:
                # Stage 3: tune ki/kd against settling time
                A = choose_action2(S, q_table)
                S_, R = get_env_feedback2(S, A)
                q_predict = q_table.loc[S, A]
                q_target = R + GAMMA * q_table.loc[S_, :].max()
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)
                S = S_
                if S == N_STATES[6]:
                    is_terminated = True    # settling-time success ends the episode
    return q_table
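# Each stage above applies the standard tabular Q-learning update:
#     Q(S, A) <- Q(S, A) + ALPHA * (R + GAMMA * max_a' Q(S_, a') - Q(S, A))
# e.g. from a zero table, a transition with R = 5 moves Q(S, A) to 0 + 0.1 * 5 = 0.5.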
if __name__ == "__main__":
    pwm = PCA9685(0x60, debug=False)
    pwm.setPWMFreq(50)    # servos expect a 50 Hz PWM signal
    board = Board()
    dis_min = 0       # minimum ranging threshold: 0 mm
    dis_max = 4500    # maximum ranging threshold: 4500 mm
    board.set_dis_range(dis_min, dis_max)
    q_table = rl()
    print('\r\nQ-table:\n')
    print(q_table)
    q_table.to_csv("/home/pi/pid.csv")    # persist the learned table for the next run