|
@@ -0,0 +1,316 @@
|
|
|
+#coding=utf-8
|
|
|
+
|
|
|
+import os
|
|
|
+import numpy as np
|
|
|
+import tensorflow as tf
|
|
|
+import datetime
|
|
|
+import cv2
|
|
|
+import imutils
|
|
|
+import time
|
|
|
+import os.path
|
|
|
+
|
|
|
+from xml.dom import minidom
|
|
|
+from PIL import Image
|
|
|
+
|
|
|
+# -------------------------------------------
|
|
|
+extract_to = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
|
|
|
+dataset_images = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
|
|
|
+dataset_labels = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
|
|
|
+folderCharacter = "/" # \\ is for windows
|
|
|
+xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_file.txt"
|
|
|
+object_xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_object.txt"
|
|
|
+
|
|
|
+def chkEnv():
|
|
|
+ if not os.path.exists(extract_to):
|
|
|
+ os.makedirs(extract_to)
|
|
|
+ print("no {} folder, created.".format(extract_to))
|
|
|
+ if (not os.path.exists(dataset_images)):
|
|
|
+ print("There is no such folder {}".format(dataset_images))
|
|
|
+ quit()
|
|
|
+ if (not os.path.exists(dataset_labels)):
|
|
|
+ print("There is no such folder {}".format(dataset_labels))
|
|
|
+ quit()
|
|
|
+ if (not os.path.exists(xml_file)):
|
|
|
+ f = open(xml_file, 'w')
|
|
|
+ f.close()
|
|
|
+ print("There is no xml file in {} ,created.".format(xml_file))
|
|
|
+ if (not os.path.exists(object_xml_file)):
|
|
|
+ f = open(object_xml_file, 'w')
|
|
|
+ f.close()
|
|
|
+ print("There is no object xml file in {} ,created.".format(object_xml_file))
|
|
|
+
|
|
|
+def getLabels(imgFile, xmlFile):
|
|
|
+ labelXML = minidom.parse(xmlFile)
|
|
|
+ labelName = []
|
|
|
+ labelXmin = []
|
|
|
+ labelYmin = []
|
|
|
+ labelXmax = []
|
|
|
+ labelYmax = []
|
|
|
+ totalW = 0
|
|
|
+ totalH = 0
|
|
|
+ countLabels = 0
|
|
|
+
|
|
|
+ tmpArrays = labelXML.getElementsByTagName("name")
|
|
|
+ for elem in tmpArrays:
|
|
|
+ labelName.append(str(elem.firstChild.data))
|
|
|
+ tmpArrays = labelXML.getElementsByTagName("xmin")
|
|
|
+ for elem in tmpArrays:
|
|
|
+ labelXmin.append(int(elem.firstChild.data))
|
|
|
+ tmpArrays = labelXML.getElementsByTagName("ymin")
|
|
|
+ for elem in tmpArrays:
|
|
|
+ labelYmin.append(int(elem.firstChild.data))
|
|
|
+ tmpArrays = labelXML.getElementsByTagName("xmax")
|
|
|
+ for elem in tmpArrays:
|
|
|
+ labelXmax.append(int(elem.firstChild.data))
|
|
|
+ tmpArrays = labelXML.getElementsByTagName("ymax")
|
|
|
+ for elem in tmpArrays:
|
|
|
+ labelYmax.append(int(elem.firstChild.data))
|
|
|
+
|
|
|
+ return labelName, labelXmin, labelYmin, labelXmax, labelYmax
|
|
|
+
|
|
|
+def write_lale_images(label, img, saveto, filename):
|
|
|
+ writePath = extract_to + label
|
|
|
+ print("WRITE:", writePath)
|
|
|
+ if not os.path.exists(writePath):
|
|
|
+ os.makedirs(writePath)
|
|
|
+
|
|
|
+ cv2.imwrite(writePath + folderCharacter + filename, img)
|
|
|
+
|
|
|
+def extract_lab_to_imgs():
|
|
|
+ chkEnv()
|
|
|
+ i = 0
|
|
|
+ for file in os.listdir(dataset_images):
|
|
|
+ filename, file_extension = os.path.splitext(file)
|
|
|
+ file_extension = file_extension.lower()
|
|
|
+ if (
|
|
|
+ file_extension == ".jpg" or file_extension == ".jpeg" or file_extension == ".png" or file_extension == ".bmp"):
|
|
|
+ print("Processing: ", dataset_images + file)
|
|
|
+ if not os.path.exists(dataset_labels + filename + ".xml"):
|
|
|
+ print("Cannot find the file {} for the image.".format(dataset_labels + filename + ".xml"))
|
|
|
+ else:
|
|
|
+ image_path = dataset_images + file
|
|
|
+ xml_path = dataset_labels + filename + ".xml"
|
|
|
+ labelName, labelXmin, labelYmin, labelXmax, labelYmax = getLabels(image_path, xml_path)
|
|
|
+ orgImage = cv2.imread(image_path)
|
|
|
+ image = orgImage.copy()
|
|
|
+ for id, label in enumerate(labelName):
|
|
|
+ cv2.rectangle(image, (labelXmin[id], labelYmin[id]), (labelXmax[id], labelYmax[id]), (0, 255, 0), 2)
|
|
|
+ label_area = orgImage[labelYmin[id]:labelYmax[id], labelXmin[id]:labelXmax[id]]
|
|
|
+ label_img_filename = filename + "_" + str(id) + ".png"
|
|
|
+ write_lale_images(labelName[id], label_area, extract_to, label_img_filename)
|
|
|
+ #cv2.imshow("Image", imutils.resize(image, width=700))
|
|
|
+ #k = cv2.waitKey(1)
|
|
|
+
|
|
|
+def turn_image():
|
|
|
+ dir1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\brokenbean\\'
|
|
|
+ save_path1 = "C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\"
|
|
|
+
|
|
|
+ dir2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\peaberrybean\\'
|
|
|
+ save_path2 = "C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\"
|
|
|
+
|
|
|
+ dir3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\shellbean\\'
|
|
|
+ save_path3 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\shellbean\\"
|
|
|
+
|
|
|
+ dir4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\wormbean\\'
|
|
|
+ save_path4 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\wormbean\\"
|
|
|
+
|
|
|
+ x = [dir1, dir2, dir3, dir4]
|
|
|
+ y = [save_path1, save_path2, save_path3, save_path4]
|
|
|
+
|
|
|
+ i = 0
|
|
|
+ c = -1
|
|
|
+ for dir in x:
|
|
|
+ dirListing = os.listdir(dir)
|
|
|
+
|
|
|
+ c = c + 1
|
|
|
+ for filename in os.listdir(dir):
|
|
|
+ img = Image.open(dir + "/" + filename)
|
|
|
+ i = i + 1
|
|
|
+ img_rotate0 = img.transpose(Image.FLIP_LEFT_RIGHT)
|
|
|
+ img_rotate1 = img.rotate(90)
|
|
|
+ img_rotate2 = img.rotate(180)
|
|
|
+ img_rotate3 = img.rotate(270)
|
|
|
+
|
|
|
+ temp = filename[0: filename.find('_')]
|
|
|
+ temp2 = int(temp) + 1200000
|
|
|
+ temp3 = str(temp2) + '_' + str(i) + filename[filename.find('_'):]
|
|
|
+ img_rotate0.save(y[c] + temp3)
|
|
|
+ temp3 = str(temp2) + '_' + str(i + len(dirListing)) + filename[filename.find('_'):]
|
|
|
+ img_rotate1.save(y[c] + temp3)
|
|
|
+ temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing)) + filename[filename.find('_'):]
|
|
|
+ img_rotate2.save(y[c] + temp3)
|
|
|
+ temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing) + len(dirListing)) + filename[
|
|
|
+ filename.find(
|
|
|
+ '_'):]
|
|
|
+ img_rotate3.save(y[c] + temp3)
|
|
|
+ print(y[c])
|
|
|
+ print(len(dirListing))
|
|
|
+ print('done')
|
|
|
+
|
|
|
+def changename():
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ path_name1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\'
|
|
|
+ path_name2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\'
|
|
|
+ path_name3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\shellbean\\'
|
|
|
+ path_name4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\wormbean\\'
|
|
|
+
|
|
|
+ save_path = 'C:\\Users\\User\\Desktop\\tfcoffebean\\data\\'
|
|
|
+ image_size = 150
|
|
|
+ a = 1
|
|
|
+ for item in os.listdir(path_name1):
|
|
|
+ image_source = cv2.imread(path_name1 + item)
|
|
|
+ image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
|
|
|
+ cv2.imwrite(save_path + ('0' + '_' + str(a)) + ".png", image)
|
|
|
+ # os.rename(os.path.join(path_name1,item),os.path.join(path_name1,('0'+'_'+str(a)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
|
|
|
+ a += 1
|
|
|
+ b = 1
|
|
|
+ for item in os.listdir(path_name2):
|
|
|
+ image_source = cv2.imread(path_name2 + item)
|
|
|
+ image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
|
|
|
+ cv2.imwrite(save_path + ('1' + '_' + str(b)) + ".png", image)
|
|
|
+ # os.rename(os.path.join(path_name2,item),os.path.join(path_name2,('1'+'_'+str(b)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
|
|
|
+ b += 1
|
|
|
+ c = 1
|
|
|
+ for item in os.listdir(path_name3): #
|
|
|
+ image_source = cv2.imread(path_name3 + item)
|
|
|
+ image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
|
|
|
+ cv2.imwrite(save_path + ('2' + '_' + str(c)) + ".png", image)
|
|
|
+ # os.rename(os.path.join(path_name3,item),os.path.join(path_name3,('2'+'_'+str(c)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
|
|
|
+ c += 1
|
|
|
+ d = 1
|
|
|
+ for item in os.listdir(path_name4):
|
|
|
+ image_source = cv2.imread(path_name4 + item)
|
|
|
+ image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
|
|
|
+ cv2.imwrite(save_path + ('3' + '_' + str(d)) + ".png", image)
|
|
|
+ # os.rename(os.path.join(path_name4,item),os.path.join(path_name4,('3'+'_'+str(d)+'.png')))#os.path.join(path_name,item)表示找到每个文件的绝对路径并进行拼接操作
|
|
|
+ d += 1
|
|
|
+
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ print("更名執行時間:", end - start)
|
|
|
+
|
|
|
+def cnn():
|
|
|
+ # data file
|
|
|
+ data_dir = "data"
|
|
|
+
|
|
|
+ # train or test
|
|
|
+ train = True
|
|
|
+ # model address
|
|
|
+ model_path = "model/image_model"
|
|
|
+
|
|
|
+ def read_data(data_dir):
|
|
|
+ datas = []
|
|
|
+ labels = []
|
|
|
+ fpaths = []
|
|
|
+ for fname in os.listdir(data_dir):
|
|
|
+ fpath = os.path.join(data_dir, fname)
|
|
|
+ fpaths.append(fpath)
|
|
|
+ image = Image.open(fpath)
|
|
|
+ data = np.array(image) / 255.0
|
|
|
+ label = int(fname.split("_")[0])
|
|
|
+
|
|
|
+ datas.append(data)
|
|
|
+ labels.append(label)
|
|
|
+
|
|
|
+ datas = np.array(datas)
|
|
|
+ labels = np.array(labels)
|
|
|
+
|
|
|
+ print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
|
|
|
+ return fpaths, datas, labels
|
|
|
+
|
|
|
+ fpaths, datas, labels = read_data(data_dir)
|
|
|
+
|
|
|
+ # num_classes = len(set(labels))
|
|
|
+ num_classes = 4
|
|
|
+
|
|
|
+ datas_placeholder = tf.placeholder(tf.float32, [None, 150, 150, 3])
|
|
|
+
|
|
|
+ labels_placeholder = tf.placeholder(tf.int32, [None])
|
|
|
+
|
|
|
+ dropout_placeholdr = tf.placeholder(tf.float32)
|
|
|
+
|
|
|
+ conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
|
|
|
+ pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])
|
|
|
+
|
|
|
+ conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
|
|
|
+ pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
|
|
|
+
|
|
|
+ flatten = tf.layers.flatten(pool1)
|
|
|
+
|
|
|
+ fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
|
|
|
+
|
|
|
+ dropout_fc = tf.layers.dropout(fc, dropout_placeholdr)
|
|
|
+
|
|
|
+ logits = tf.layers.dense(dropout_fc, num_classes)
|
|
|
+
|
|
|
+ predicted_labels = tf.arg_max(logits, 1)
|
|
|
+
|
|
|
+ losses = tf.nn.softmax_cross_entropy_with_logits(
|
|
|
+ labels=tf.one_hot(labels_placeholder, num_classes),
|
|
|
+ logits=logits
|
|
|
+ )
|
|
|
+ mean_loss = tf.reduce_mean(losses)
|
|
|
+
|
|
|
+ optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(losses)
|
|
|
+
|
|
|
+ saver = tf.train.Saver()
|
|
|
+
|
|
|
+ with tf.Session() as sess:
|
|
|
+
|
|
|
+ if train:
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ print("train mode")
|
|
|
+
|
|
|
+ sess.run(tf.global_variables_initializer())
|
|
|
+
|
|
|
+ train_feed_dict = {
|
|
|
+ datas_placeholder: datas,
|
|
|
+ labels_placeholder: labels,
|
|
|
+ dropout_placeholdr: 0.25
|
|
|
+ }
|
|
|
+ for step in range(500):
|
|
|
+ _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
|
|
|
+
|
|
|
+ if step % 5 == 0:
|
|
|
+ print("step = {}\tmean loss = {}".format(step, mean_loss_val))
|
|
|
+ saver.save(sess, model_path)
|
|
|
+ print("train done,save model{}".format(model_path))
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ print("訓練執行時間:", end - start)
|
|
|
+ else:
|
|
|
+
|
|
|
+ print("test mode")
|
|
|
+ saver.restore(sess, model_path)
|
|
|
+ print("{}reload model".format(model_path))
|
|
|
+
|
|
|
+ label_name_dict = {
|
|
|
+ 0: "Brokenbeans",
|
|
|
+ 1: "Peaberry",
|
|
|
+ 2: "shellbean",
|
|
|
+ 3: "Worms"
|
|
|
+ }
|
|
|
+ test_feed_dict = {
|
|
|
+ datas_placeholder: datas,
|
|
|
+ labels_placeholder: labels,
|
|
|
+ dropout_placeholdr: 0
|
|
|
+ }
|
|
|
+ predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
|
|
|
+
|
|
|
+ for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
|
|
|
+ # real_label_name = label_name_dict[real_label]
|
|
|
+ predicted_label_name = label_name_dict[predicted_label]
|
|
|
+ print("{}\t => {}".format(fpath, predicted_label_name))
|
|
|
+ dirListing = os.listdir(data_dir)
|
|
|
+ print(len(dirListing))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ start = datetime.datetime.now()
|
|
|
+ extract_lab_to_imgs()
|
|
|
+ end = datetime.datetime.now()
|
|
|
+ print("xml轉換img執行時間:", end - start)
|
|
|
+ start1 = datetime.datetime.now()
|
|
|
+ turn_image()
|
|
|
+ end1 = datetime.datetime.now()
|
|
|
+ print("轉動圖片執行時間:", end1 - start1)
|
|
|
+ changename()
|
|
|
+ cnn()
|