STEPCTQL.py

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
Table-lookup Q-learning for tuning PID gains on a stepper-motor positioning rig.
The agent adjusts kp/ki/kd in small increments and is rewarded according to the
rise time, overshoot and settling time measured while driving the position
(reported over MQTT) toward the goal of 320.
Adapted from the tabular Q-learning tutorial at https://morvanzhou.github.io/tutorials/
"""
import numpy as np
import pandas as pd
import time
import smbus
import math
import sympy
from sympy import asin, cos, sin, acos, tan, atan
# from DFRobot_RaspberryPi_A02YYUW import DFRobot_A02_Distance as Board
import paho.mqtt.client as mqtt
import json
import RPi.GPIO as GPIO
from time import sleep

np.random.seed(2)  # reproducible
# Stepper motor setup
GPIO.setmode(GPIO.BCM)
DIR = 23
STEP = 24
CW = 1
CCW = 0
GPIO.setup(DIR, GPIO.OUT)
GPIO.setup(STEP, GPIO.OUT)
GPIO.output(DIR, CW)

N_STATES = ['raise_time<1', '1<=raise_time<2', '2<=raise_time<4', 'raise_time>4',
            '0<=overshoot<0.33', '0.33<overshoot<1',
            '10<=setingtime<20', '20<=setingtime<30']
goal = 320  # target position
ACTIONS = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1',
           'ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.01', 'ki-0.1',
           'kd+0.1', 'kd+0', 'kd-0.1']  # available actions
EPSILON = 0.9     # greedy policy
ALPHA = 0.1       # learning rate
GAMMA = 0.9       # discount factor
MAX_EPISODES = 1  # maximum episodes
FRESH_TIME = 0.1  # fresh time for one move
kp = 0.0
ki = 0.0
kd = 0.0
count = 50
x = 0
jsonmsg = ""
S_ = ""
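
# State and action design: the eight N_STATES labels cover three tuning stages
# (rise time, overshoot, settling time), and each ACTION nudges one PID gain
# (kp, ki or kd) by a fixed increment.  kp, ki and kd hold the gains being
# tuned, x is the latest position reported over MQTT, and count is the number
# of control iterations used per evaluation run.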

def step(CW, plus):
    global DIR, STEP
    GPIO.output(DIR, CW)
    for _ in range(plus):
        # Set one coil winding to high
        GPIO.output(STEP, GPIO.HIGH)
        # Allow it to get there.
        sleep(0.001)  # Dictates how fast the stepper motor will run
        # Set coil winding to low
        GPIO.output(STEP, GPIO.LOW)
        sleep(0.001)
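
# step() pulses the STEP pin `plus` times with the direction taken from CW;
# each pulse is held high for 1 ms and low for 1 ms, i.e. roughly 500 steps
# per second.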

def on_connect(client, userdata, flags, rc):
    print("Connected with result code " + str(rc))
    client.subscribe("b8:27:eb:eb:21:13/Log", qos=2)


# Called whenever a message arrives from the broker
def on_message(client, userdata, msg):
    # Decode the payload as UTF-8 so non-ASCII text is readable
    global jsonmsg
    global x
    msg.payload = msg.payload.decode('utf-8')
    jsonmsg = json.loads(msg.payload)
    x = int(jsonmsg['x'])

class PIDController:
    def __init__(self, Kp, Ki, Kd):
        self.Kp = Kp
        self.Ki = Ki
        self.Kd = Kd
        self.last_error = 0
        self.integral = 0

    def control(self, error):
        output = self.Kp * error + self.Ki * self.integral + self.Kd * (error - self.last_error)
        self.integral += error
        self.last_error = error
        return output
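
# control() implements the discrete positional PID law
#     output = Kp*e[k] + Ki*sum(e[0..k-1]) + Kd*(e[k] - e[k-1])
# The error is added to the integral only after the output is computed, so the
# I term lags the current error by one sample.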

class Any_System:
    def __init__(self, goal):
        self.target = goal
        self.current = 0

    def update(self, control_signal):
        self.current += control_signal
        return self.current

    def get_error(self):
        return self.target - self.current
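
# Any_System mainly carries the setpoint (target).  The train* routines below
# read the real position from the MQTT-updated global x; the simulated
# update() path is left commented out inside them.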

def train(controller, system, num_iterations):
    global jsonmsg
    global x
    global S_, R
    errors = []
    raise_time = 0
    cont = 0
    rtz = 0
    for _ in range(num_iterations):
        # error = system.get_error()
        raise_time += 1
        current = x
        error = system.target - current  # measured (real) signal
        output = controller.control(error)
        output = output * 4
        print(raise_time, current, output)
        if abs(rtz) < 7750:
            if output > 0:
                step(1, int(output))
                rtz -= int(output)
            else:
                step(0, int(-1 * output))
                rtz += int(-1 * output)
        time.sleep(0.5)
        # control_signal = controller.control(error)
        # current = system.update(control_signal)
        if (current - system.target) >= 0:
            cont = raise_time
            break
        else:
            cont = 50
    if cont <= 10:
        S_ = N_STATES[0]
        R = 5
        print('raise_time success')
    elif (10 < cont) and (cont <= 20):
        S_ = N_STATES[1]
        R = (15 - cont) / 10
    elif (20 < cont) and (cont <= 40):
        S_ = N_STATES[2]
        R = (30 - cont) / 10
    else:
        S_ = N_STATES[3]
        R = -(error / 100)
    if rtz > 0:
        step(1, int(rtz))
    else:
        step(0, int(-1 * rtz))
    time.sleep(3)
    print("rtz:", rtz)
    return S_, R
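
# train() evaluates the current Kp on rise time: it steps the motor until the
# measured position reaches the target (or num_iterations runs out), maps the
# iteration count onto one of the rise-time states N_STATES[0..3] with a
# matching reward, then rewinds the accumulated rtz steps to return the motor
# to its starting position.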

def train2(controller, system, num_iterations):
    global x
    global S_, R
    errors = []
    current_arr = []
    ot = 0
    over_time = 0
    over_shoot = 0
    rtz = 0
    for _ in range(num_iterations):
        # error = system.get_error()
        # current = 23.5 - (board.getDistance() / 10)
        current = x
        value = current
        error = system.target - current  # measured (real) signal
        output = controller.control(error)
        output = output * 4
        if abs(rtz) < 7750:
            if output > 0:
                step(1, int(output))
                rtz -= int(output)
            else:
                step(0, int(-1 * output))
                rtz += int(-1 * output)
        time.sleep(0.5)
        over_time += 1
        if value > ot:
            ot = value
        print(over_time, ot, output, rtz)
    over_shoot = float(abs(ot - system.target)) / 320
    print("overshoot", str(over_shoot))
    if over_shoot >= 0 and over_shoot < 0.0625:
        print('overshoot success')
        S_ = N_STATES[4]
        R = 5
    elif (0.0625 <= over_shoot) and (over_shoot < 1):
        S_ = N_STATES[5]
        R = -1 * over_shoot
    else:
        S_ = N_STATES[0]
        R = 0
    if rtz > 0:
        step(1, int(rtz))
    else:
        step(0, int(-1 * rtz))
    time.sleep(3)
    print("rtz:", rtz)
    return S_, R
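
# train2() evaluates overshoot: it tracks the peak position ot over the run,
# computes over_shoot = |peak - target| / 320, and maps it onto the overshoot
# states N_STATES[4..5] (falling back to N_STATES[0] otherwise) before
# rewinding rtz.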

def train3(controller, system, num_iterations):
    global x
    global S_, R
    errors = []
    cont = 0
    setingtime = 0
    con = 0
    rtz = 0
    print("3")
    for _ in range(num_iterations):
        cont = cont + 1
        current = x
        error = system.target - current  # measured (real) signal
        output = controller.control(error)
        output = output * 4
        if abs(rtz) < 7750:
            if output > 0:
                step(1, int(output))
                rtz -= int(output)
            else:
                step(0, int(-1 * output))
                rtz += int(-1 * output)
        time.sleep(0.5)
        print(cont, current, output)
        if ((-1 * error) >= 0) and ((-1 * error) <= 5) and (con == 0):
            setingtime = cont
            con = 1
        elif con == 0:
            setingtime = 40
    print(setingtime)
    if setingtime >= 0 and setingtime < 10:
        S_ = N_STATES[6]
        R = 10
        print('setingtime success')
        with open('pid.txt', 'a') as f:
            f.write('kp:')
            f.write(str(controller.Kp))
            f.write('ki:')
            f.write(str(controller.Ki))
            f.write('kd:')
            f.write(str(controller.Kd))
            f.write('\r\n')
    elif (10 <= setingtime) and (setingtime <= 40):
        S_ = N_STATES[7]
        R = (25 - setingtime) / 10
    else:
        S_ = N_STATES[4]
        R = -(error / 100)
    if rtz > 0:
        step(1, int(rtz))
    else:
        step(0, int(-1 * rtz))
    time.sleep(3)
    print("rtz:", rtz)
    return S_, R
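
# train3() evaluates settling time: it records the first iteration at which the
# position sits within 5 counts above the target, maps that onto the settling
# states N_STATES[6..7] (falling back to N_STATES[4] otherwise), and on success
# appends the current kp/ki/kd gains to pid.txt before rewinding rtz.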

def build_q_table(n_states, actions):
    try:
        table = pd.read_csv("/home/pi/pid.csv", index_col=0)
    except Exception:
        table = pd.DataFrame(
            np.zeros((len(n_states), len(actions))),  # q_table initial values
            columns=actions, index=n_states,          # actions' names
        )
    print(table)  # show table
    return table
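
# The Q-table is a pandas DataFrame with one row per state label and one column
# per action; a previously learned table is reloaded from /home/pi/pid.csv when
# it exists, otherwise the table starts at all zeros.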

def choose_action(state, q_table):
    # This is how to choose an action
    state_actions = q_table.loc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy, or this state's actions have no value yet
        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
        action_name = np.random.choice(ACT)
    else:  # act greedy
        action_name = state_actions.idxmax()  # idxmax replaces argmax, which means something different in newer pandas versions
    return action_name
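
# Epsilon-greedy selection: with probability 1 - EPSILON, or whenever the row is
# still all zeros, a random action is drawn from the kp-only subset; otherwise
# the action with the highest Q-value for this state is taken.  choose_action1
# below is identical, and choose_action2 explores over the ki subset instead.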

def choose_action1(state, q_table):
    # This is how to choose an action
    state_actions = q_table.loc[state, :]
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy, or this state's actions have no value yet
        ACT = ['kp+1', 'kp+0.1', 'kp+0.01', 'kp+0', 'kp-0.01', 'kp-0.1', 'kp-1']
        action_name = np.random.choice(ACT)
    else:  # act greedy
        action_name = state_actions.idxmax()  # idxmax replaces argmax, which means something different in newer pandas versions
    return action_name

def choose_action2(state, q_table):
    # This is how to choose an action
    state_actions = q_table.loc[state, :]
    print("3")
    if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()):  # act non-greedy, or this state's actions have no value yet
        ACT = ['ki+0.1', 'ki+0.01', 'ki+0', 'ki-0.1', 'ki-0.01']
        action_name = np.random.choice(ACT)
    else:  # act greedy
        action_name = state_actions.idxmax()  # idxmax replaces argmax, which means something different in newer pandas versions
    print(action_name)
    return action_name

def pid(kp):
    global goal, count
    global S_
    R = 0
    print("raisetime")
    pid_controller = PIDController(kp, 0.0, 0.0)
    any_system = Any_System(goal)
    S_, R = train(pid_controller, any_system, count)
    print('kp:', kp)
    return S_, R


def pid1(kp):
    global S_
    R = 0
    print("overshoot")
    pid_controller = PIDController(kp, 0.0, 0.0)
    any_system = Any_System(goal)
    S_, R = train2(pid_controller, any_system, count)
    print('kp:', kp)
    return S_, R


def pid2(kp, ki, kd):
    global S_
    R = 0
    print("setingtime")
    pid_controller = PIDController(kp, ki, kd)
    any_system = Any_System(goal)
    S_, R = train3(pid_controller, any_system, count)
    print('kp:', kp, 'ki', ki, 'kd', kd)
    return S_, R
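
# Staged tuning: pid() scores a Kp candidate on rise time (train), pid1() scores
# it on overshoot (train2), and pid2() scores the full kp/ki/kd set on settling
# time (train3); each returns the next state label and reward for Q-learning.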

def get_env_feedback(S, A):
    # This is how the agent interacts with the environment (rise-time stage: kp only)
    global kp, S_
    R = 0
    if A == 'kp+1':
        kp += 1
        S_, R = pid(kp)
    elif A == 'kp+0.1':
        kp += 0.1
        S_, R = pid(kp)
    elif A == 'kp+0.01':
        kp += 0.01
        S_, R = pid(kp)
    elif A == 'kp+0':
        kp = kp + 0
        S_, R = pid(kp)
    elif A == 'kp-0.01':
        kp -= 0.01
        S_, R = pid(kp)
    elif A == 'kp-0.1':
        kp -= 0.1
        S_, R = pid(kp)
    elif A == 'kp-1':
        kp -= 1
        S_, R = pid(kp)
    return S_, R

def get_env_feedback1(S, A):
    # This is how the agent interacts with the environment (overshoot stage: kp only)
    global kp, S_
    R = 0
    if A == 'kp+1':
        kp += 1
        S_, R = pid1(kp)
    elif A == 'kp+0.1':
        kp += 0.1
        S_, R = pid1(kp)
    elif A == 'kp+0.01':
        kp += 0.01
        S_, R = pid1(kp)
    elif A == 'kp+0':
        kp = kp + 0
        S_, R = pid1(kp)
    elif A == 'kp-0.01':
        kp -= 0.01
        S_, R = pid1(kp)
    elif A == 'kp-0.1':
        kp -= 0.1
        S_, R = pid1(kp)
    elif A == 'kp-1':
        kp -= 1
        S_, R = pid1(kp)
    return S_, R

def get_env_feedback2(S, A):
    # This is how the agent interacts with the environment (settling-time stage: ki and kd)
    global ki
    global kp
    global kd, S_
    R = 0
    if A == 'ki+0.1':
        ki += 0.1
        S_, R = pid2(kp, ki, kd)
    elif A == 'ki+0.01':
        ki += 0.01
        S_, R = pid2(kp, ki, kd)
    elif A == 'ki+0':
        ki = ki + 0
        S_, R = pid2(kp, ki, kd)
    elif A == 'ki-0.1':
        ki -= 0.1
        S_, R = pid2(kp, ki, kd)
    elif A == 'ki-0.01':
        ki -= 0.01
        S_, R = pid2(kp, ki, kd)
    elif A == 'kd+0.1':
        kd += 0.1
        S_, R = pid2(kp, ki, kd)
    elif A == 'kd+0':
        kd = kd + 0
        S_, R = pid2(kp, ki, kd)
    elif A == 'kd-0.1':
        kd -= 0.1
        S_, R = pid2(kp, ki, kd)
    return S_, R

def update_env(S, episode, step_counter):
    # This is how the environment would be reported (currently disabled)
    interaction = 'Episode %s: raise_time= %s' % (episode + 1, S)
    # print('\r{}'.format(interaction), end='')
    # print('Episode %s: raise_time= %s\r\n' % (episode + 1, S))

def rl():
    global x, y, z
    # main part of the RL loop
    q_table = build_q_table(N_STATES, ACTIONS)
    for episode in range(MAX_EPISODES):
        S = N_STATES[3]
        is_terminated = False
        step(1, 7750)
        time.sleep(5)
        while not is_terminated:
            # update_env(S, episode, step_counter)
            if S == N_STATES[3] or S == N_STATES[2] or S == N_STATES[1] or S == N_STATES[0]:
                A = choose_action(S, q_table)
                S_, R = get_env_feedback(S, A)  # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
                print(q_table)
                S = S_  # move to next state
                # update_env(S, episode, step_counter)
                # step_counter += 1
                if S == N_STATES[0]:
                    S = N_STATES[5]
            elif S == N_STATES[4] or S == N_STATES[5]:
                A = choose_action1(S, q_table)
                S_, R = get_env_feedback1(S, A)  # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                print(q_table)
                q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
                S = S_  # move to next state
                # update_env(S, episode, step_counter)
                # step_counter += 1
                if S == N_STATES[4]:
                    S = N_STATES[7]
            elif S == N_STATES[6] or S == N_STATES[7]:
                A = choose_action2(S, q_table)
                S_, R = get_env_feedback2(S, A)  # take action & get next state and reward
                q_predict = q_table.loc[S, A]
                print(q_table)
                q_target = R + GAMMA * q_table.loc[S_, :].max()  # next state is not terminal
                q_table.loc[S, A] += ALPHA * (q_target - q_predict)  # update
                S = S_  # move to next state
                if S == N_STATES[6]:
                    is_terminated = True
            # update_env(S, episode, step_counter)
    return q_table
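
# rl() runs tabular Q-learning over the three stages: it starts from the
# 'raise_time>4' state, picks actions epsilon-greedily, and applies the update
#     Q(S, A) += ALPHA * (R + GAMMA * max_a Q(S_, a) - Q(S, A))
# Once the rise-time stage succeeds it jumps to the overshoot states, then to
# the settling-time states, and terminates when '10<=setingtime<20' is reached.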

def test_pid():
    global x
    step(1, 7750)
    time.sleep(5)
    p_prev_error = 0
    p_integral = 0
    # ----------------------------------------
    # Compute each joint angle
    while True:
        current = x
        p_error = 320 - current
        p_output = 5.32 * p_error + 0.0 * p_integral + 0.00 * (p_error - p_prev_error)
        p_integral += p_error
        p_prev_error = p_error
        if p_output > 0:
            step(1, int(p_output))
        else:
            step(0, int(-1 * p_output))
        time.sleep(0.1)
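
# test_pid() bypasses learning and runs a fixed proportional controller
# (Kp = 5.32, Ki = Kd = 0) forever, driving the stepper toward the 320 setpoint
# using the MQTT-fed position x.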

if __name__ == "__main__":
    client = mqtt.Client()
    client.username_pw_set(username='aisky-client', password='aiskyc')
    client.connect("60.250.156.234", 1883, 60)
    client.on_connect = on_connect
    client.on_message = on_message
    client.loop_start()
    a = "T"
    if a == "a":
        q_table = rl()
        print('\r\nQ-table:\n')
        print(q_table)
        q_table.to_csv("/home/pi/pid.csv")
    else:
        test_pid()
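
# Usage sketch: the script expects position feedback published on the MQTT topic
# "b8:27:eb:eb:21:13/Log" as JSON with an integer "x" field (for example, a
# payload such as {"x": 150}).  With a = "T" as above only the fixed-gain
# test_pid() loop runs; change a to "a" to run Q-learning and write the learned
# table to /home/pi/pid.csv.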