# Coffee_project_train.py (12 KB)
  1. #coding=utf-8
  2. import os
  3. import numpy as np
  4. import tensorflow as tf
  5. import datetime
  6. import cv2
  7. import imutils
  8. import time
  9. import os.path
  10. from xml.dom import minidom
  11. from PIL import Image
  12. # -------------------------------------------
  13. extract_to = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
  14. dataset_images = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
  15. dataset_labels = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
  16. folderCharacter = "/" # \\ is for windows
  17. xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_file.txt"
  18. object_xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_object.txt"
  19. def chkEnv():
  20. if not os.path.exists(extract_to):
  21. os.makedirs(extract_to)
  22. print("no {} folder, created.".format(extract_to))
  23. if (not os.path.exists(dataset_images)):
  24. print("There is no such folder {}".format(dataset_images))
  25. quit()
  26. if (not os.path.exists(dataset_labels)):
  27. print("There is no such folder {}".format(dataset_labels))
  28. quit()
  29. if (not os.path.exists(xml_file)):
  30. f = open(xml_file, 'w')
  31. f.close()
  32. print("There is no xml file in {} ,created.".format(xml_file))
  33. if (not os.path.exists(object_xml_file)):
  34. f = open(object_xml_file, 'w')
  35. f.close()
  36. print("There is no object xml file in {} ,created.".format(object_xml_file))
  37. def getLabels(imgFile, xmlFile):
  38. labelXML = minidom.parse(xmlFile)
  39. labelName = []
  40. labelXmin = []
  41. labelYmin = []
  42. labelXmax = []
  43. labelYmax = []
  44. totalW = 0
  45. totalH = 0
  46. countLabels = 0
  47. tmpArrays = labelXML.getElementsByTagName("name")
  48. for elem in tmpArrays:
  49. labelName.append(str(elem.firstChild.data))
  50. tmpArrays = labelXML.getElementsByTagName("xmin")
  51. for elem in tmpArrays:
  52. labelXmin.append(int(elem.firstChild.data))
  53. tmpArrays = labelXML.getElementsByTagName("ymin")
  54. for elem in tmpArrays:
  55. labelYmin.append(int(elem.firstChild.data))
  56. tmpArrays = labelXML.getElementsByTagName("xmax")
  57. for elem in tmpArrays:
  58. labelXmax.append(int(elem.firstChild.data))
  59. tmpArrays = labelXML.getElementsByTagName("ymax")
  60. for elem in tmpArrays:
  61. labelYmax.append(int(elem.firstChild.data))
  62. return labelName, labelXmin, labelYmin, labelXmax, labelYmax
  63. def write_lale_images(label, img, saveto, filename):
  64. writePath = extract_to + label
  65. print("WRITE:", writePath)
  66. if not os.path.exists(writePath):
  67. os.makedirs(writePath)
  68. cv2.imwrite(writePath + folderCharacter + filename, img)
  69. def extract_lab_to_imgs():
  70. chkEnv()
  71. i = 0
  72. for file in os.listdir(dataset_images):
  73. filename, file_extension = os.path.splitext(file)
  74. file_extension = file_extension.lower()
  75. if (
  76. file_extension == ".jpg" or file_extension == ".jpeg" or file_extension == ".png" or file_extension == ".bmp"):
  77. print("Processing: ", dataset_images + file)
  78. if not os.path.exists(dataset_labels + filename + ".xml"):
  79. print("Cannot find the file {} for the image.".format(dataset_labels + filename + ".xml"))
  80. else:
  81. image_path = dataset_images + file
  82. xml_path = dataset_labels + filename + ".xml"
  83. labelName, labelXmin, labelYmin, labelXmax, labelYmax = getLabels(image_path, xml_path)
  84. orgImage = cv2.imread(image_path)
  85. image = orgImage.copy()
  86. for id, label in enumerate(labelName):
  87. cv2.rectangle(image, (labelXmin[id], labelYmin[id]), (labelXmax[id], labelYmax[id]), (0, 255, 0), 2)
  88. label_area = orgImage[labelYmin[id]:labelYmax[id], labelXmin[id]:labelXmax[id]]
  89. label_img_filename = filename + "_" + str(id) + ".png"
  90. write_lale_images(labelName[id], label_area, extract_to, label_img_filename)
  91. #cv2.imshow("Image", imutils.resize(image, width=700))
  92. #k = cv2.waitKey(1)
  93. def turn_image():
  94. dir1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\brokenbean\\'
  95. save_path1 = "C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\"
  96. dir2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\peaberrybean\\'
  97. save_path2 = "C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\"
  98. dir3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\shellbean\\'
  99. save_path3 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\shellbean\\"
  100. dir4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\wormbean\\'
  101. save_path4 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\wormbean\\"
  102. x = [dir1, dir2, dir3, dir4]
  103. y = [save_path1, save_path2, save_path3, save_path4]
  104. i = 0
  105. c = -1
  106. for dir in x:
  107. dirListing = os.listdir(dir)
  108. c = c + 1
  109. for filename in os.listdir(dir):
  110. img = Image.open(dir + "/" + filename)
  111. i = i + 1
  112. img_rotate0 = img.transpose(Image.FLIP_LEFT_RIGHT)
  113. img_rotate1 = img.rotate(90)
  114. img_rotate2 = img.rotate(180)
  115. img_rotate3 = img.rotate(270)
  116. temp = filename[0: filename.find('_')]
  117. temp2 = int(temp) + 1200000
  118. temp3 = str(temp2) + '_' + str(i) + filename[filename.find('_'):]
  119. img_rotate0.save(y[c] + temp3)
  120. temp3 = str(temp2) + '_' + str(i + len(dirListing)) + filename[filename.find('_'):]
  121. img_rotate1.save(y[c] + temp3)
  122. temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing)) + filename[filename.find('_'):]
  123. img_rotate2.save(y[c] + temp3)
  124. temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing) + len(dirListing)) + filename[
  125. filename.find(
  126. '_'):]
  127. img_rotate3.save(y[c] + temp3)
  128. print(y[c])
  129. print(len(dirListing))
  130. print('done')
  131. def changename():
  132. start = datetime.datetime.now()
  133. path_name1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\'
  134. path_name2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\'
  135. path_name3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\shellbean\\'
  136. path_name4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\wormbean\\'
  137. save_path = 'C:\\Users\\User\\Desktop\\tfcoffebean\\data\\'
  138. image_size = 150
  139. a = 1
  140. for item in os.listdir(path_name1):
  141. image_source = cv2.imread(path_name1 + item)
  142. image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
  143. cv2.imwrite(save_path + ('0' + '_' + str(a)) + ".png", image)
  144. # os.rename(os.path.join(path_name1,item),os.path.join(path_name1,('0'+'_'+str(a)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
  145. a += 1
  146. b = 1
  147. for item in os.listdir(path_name2):
  148. image_source = cv2.imread(path_name2 + item)
  149. image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
  150. cv2.imwrite(save_path + ('1' + '_' + str(b)) + ".png", image)
  151. # os.rename(os.path.join(path_name2,item),os.path.join(path_name2,('1'+'_'+str(b)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
  152. b += 1
  153. c = 1
  154. for item in os.listdir(path_name3): #
  155. image_source = cv2.imread(path_name3 + item)
  156. image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
  157. cv2.imwrite(save_path + ('2' + '_' + str(c)) + ".png", image)
  158. # os.rename(os.path.join(path_name3,item),os.path.join(path_name3,('2'+'_'+str(c)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
  159. c += 1
  160. d = 1
  161. for item in os.listdir(path_name4):
  162. image_source = cv2.imread(path_name4 + item)
  163. image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
  164. cv2.imwrite(save_path + ('3' + '_' + str(d)) + ".png", image)
  165. # os.rename(os.path.join(path_name4,item),os.path.join(path_name4,('3'+'_'+str(d)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
  166. d += 1
  167. end = datetime.datetime.now()
  168. print("更名執行時間:", end - start)
  169. def cnn():
  170. # data file
  171. data_dir = "data"
  172. # train or test
  173. train = True
  174. # model address
  175. model_path = "model/image_model"
  176. def read_data(data_dir):
  177. datas = []
  178. labels = []
  179. fpaths = []
  180. for fname in os.listdir(data_dir):
  181. fpath = os.path.join(data_dir, fname)
  182. fpaths.append(fpath)
  183. image = Image.open(fpath)
  184. data = np.array(image) / 255.0
  185. label = int(fname.split("_")[0])
  186. datas.append(data)
  187. labels.append(label)
  188. datas = np.array(datas)
  189. labels = np.array(labels)
  190. print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
  191. return fpaths, datas, labels
  192. fpaths, datas, labels = read_data(data_dir)
  193. # num_classes = len(set(labels))
  194. num_classes = 4
  195. datas_placeholder = tf.placeholder(tf.float32, [None, 150, 150, 3])
  196. labels_placeholder = tf.placeholder(tf.int32, [None])
  197. dropout_placeholdr = tf.placeholder(tf.float32)
  198. conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
  199. pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])
  200. conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
  201. pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
  202. flatten = tf.layers.flatten(pool1)
  203. fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
  204. dropout_fc = tf.layers.dropout(fc, dropout_placeholdr)
  205. logits = tf.layers.dense(dropout_fc, num_classes)
  206. predicted_labels = tf.arg_max(logits, 1)
  207. losses = tf.nn.softmax_cross_entropy_with_logits(
  208. labels=tf.one_hot(labels_placeholder, num_classes),
  209. logits=logits
  210. )
  211. mean_loss = tf.reduce_mean(losses)
  212. optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(losses)
  213. saver = tf.train.Saver()
  214. with tf.Session() as sess:
  215. if train:
  216. start = datetime.datetime.now()
  217. print("train mode")
  218. sess.run(tf.global_variables_initializer())
  219. train_feed_dict = {
  220. datas_placeholder: datas,
  221. labels_placeholder: labels,
  222. dropout_placeholdr: 0.25
  223. }
  224. for step in range(500):
  225. _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
  226. if step % 5 == 0:
  227. print("step = {}\tmean loss = {}".format(step, mean_loss_val))
  228. saver.save(sess, model_path)
  229. print("train done,save model{}".format(model_path))
  230. end = datetime.datetime.now()
  231. print("訓練執行時間:", end - start)
  232. else:
  233. print("test mode")
  234. saver.restore(sess, model_path)
  235. print("{}reload model".format(model_path))
  236. label_name_dict = {
  237. 0: "Brokenbeans",
  238. 1: "Peaberry",
  239. 2: "shellbean",
  240. 3: "Worms"
  241. }
  242. test_feed_dict = {
  243. datas_placeholder: datas,
  244. labels_placeholder: labels,
  245. dropout_placeholdr: 0
  246. }
  247. predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
  248. for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
  249. # real_label_name = label_name_dict[real_label]
  250. predicted_label_name = label_name_dict[predicted_label]
  251. print("{}\t => {}".format(fpath, predicted_label_name))
  252. dirListing = os.listdir(data_dir)
  253. print(len(dirListing))
  254. if __name__ == '__main__':
  255. start = datetime.datetime.now()
  256. extract_lab_to_imgs()
  257. end = datetime.datetime.now()
  258. print("xml轉換img執行時間:", end - start)
  259. start1 = datetime.datetime.now()
  260. turn_image()
  261. end1 = datetime.datetime.now()
  262. print("轉動圖片執行時間:", end1 - start1)
  263. changename()
  264. cnn()