123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316 |
- #coding=utf-8
- import os
- import numpy as np
- import tensorflow as tf
- import datetime
- import cv2
- import imutils
- import time
- import os.path
- from xml.dom import minidom
- from PIL import Image
- # -------------------------------------------
# Root folder of the coffee-bean project; the paths below are derived from it
# so they cannot drift apart.
_PROJECT_ROOT = "C:\\Users\\User\\Desktop\\tfcoffebean\\"

# Cropped label images are written beneath this folder, one sub-folder per
# label name.
extract_to = _PROJECT_ROOT + "allpic\\"
# The source images and their annotation XML files share the crop folder.
dataset_images = extract_to
dataset_labels = extract_to
# Separator placed between a label folder and the file name when a crop is
# saved ("\\" may be used instead on Windows).
folderCharacter = "/"
# Placeholder text files; chkEnv() creates them empty when they are missing.
xml_file = _PROJECT_ROOT + "xml_file.txt"
object_xml_file = _PROJECT_ROOT + "xml_object.txt"
def chkEnv():
    """Check the folders and placeholder files the pipeline relies on.

    Creates ``extract_to`` when it is missing, stops the program when
    ``dataset_images`` or ``dataset_labels`` does not exist, and creates
    empty ``xml_file`` / ``object_xml_file`` files when they are absent.

    Raises:
        SystemExit: with exit status 1 when the image or label folder is
            missing.
    """
    if not os.path.exists(extract_to):
        os.makedirs(extract_to)
        print("no {} folder, created.".format(extract_to))

    for folder in (dataset_images, dataset_labels):
        if not os.path.exists(folder):
            print("There is no such folder {}".format(folder))
            # quit() is an interactive-shell helper installed by the site
            # module; raise SystemExit directly and report a failing status.
            raise SystemExit(1)

    placeholders = (
        (xml_file, "There is no xml file in {} ,created."),
        (object_xml_file, "There is no object xml file in {} ,created."),
    )
    for path, message in placeholders:
        if not os.path.exists(path):
            # Opening in write mode creates the empty file; the context
            # manager closes the handle even if an error interrupts us.
            with open(path, 'w'):
                pass
            print(message.format(path))
def getLabels(imgFile, xmlFile):
    """Read label names and bounding boxes from an annotation XML file.

    When the document contains ``<object>`` elements, one entry is produced
    per object from the first ``name``/``xmin``/``ymin``/``xmax``/``ymax``
    element found inside it, so a ``<name>`` that appears elsewhere in the
    document cannot shift names against boxes. Objects lacking any of the
    five values are skipped. Documents without ``<object>`` elements are read
    by collecting each tag document-wide.

    Args:
        imgFile: path of the image the annotation belongs to. Not used here;
            kept so existing callers keep working.
        xmlFile: path or file object passed to ``minidom.parse``.

    Returns:
        A tuple ``(names, xmins, ymins, xmaxs, ymaxs)`` of lists; names are
        ``str`` and coordinates are ``int``.
    """
    tags = ("name", "xmin", "ymin", "xmax", "ymax")
    converters = (str, int, int, int, int)
    document = minidom.parse(xmlFile)
    columns = [[] for _ in tags]

    objects = document.getElementsByTagName("object")
    if objects:
        for obj in objects:
            row = []
            for tag, convert in zip(tags, converters):
                nodes = obj.getElementsByTagName(tag)
                if not nodes or nodes[0].firstChild is None:
                    # Incomplete object: drop it entirely rather than leave
                    # the five lists with different lengths.
                    row = None
                    break
                row.append(convert(nodes[0].firstChild.data))
            if row is None:
                continue
            for column, value in zip(columns, row):
                column.append(value)
    else:
        # No <object> wrappers: collect every occurrence of each tag.
        for column, tag, convert in zip(columns, tags, converters):
            for node in document.getElementsByTagName(tag):
                column.append(convert(node.firstChild.data))

    return tuple(columns)
def write_lale_images(label, img, saveto, filename):
    """Save one cropped label image into the folder named after its label.

    Args:
        label: label name; used as the sub-folder name beneath *saveto*.
        img: image array passed to ``cv2.imwrite``.
        saveto: base folder path. The label is appended to it directly, so
            it must end with a path separator.
        filename: file name (including extension) of the image to write.
    """
    # Honour the caller's destination instead of always writing under the
    # module-level extract_to.
    writePath = saveto + label
    print("WRITE:", writePath)
    os.makedirs(writePath, exist_ok=True)

    target = writePath + folderCharacter + filename
    # cv2.imwrite signals failure through its return value; report it so a
    # missing crop does not go unnoticed.
    if not cv2.imwrite(target, img):
        print("Failed to write {}".format(target))
def extract_lab_to_imgs():
    """Crop every annotated region out of the dataset images.

    For each ``.jpg``/``.jpeg``/``.png``/``.bmp`` file in ``dataset_images``
    that has a same-named ``.xml`` file in ``dataset_labels``, the labelled
    boxes are read with ``getLabels`` and each box is cut out of the image
    and saved through ``write_lale_images`` as ``<image name>_<index>.png``.
    Images that cannot be read and boxes with no pixels are reported and
    skipped.
    """
    chkEnv()
    image_extensions = (".jpg", ".jpeg", ".png", ".bmp")

    for file in os.listdir(dataset_images):
        filename, file_extension = os.path.splitext(file)
        if file_extension.lower() not in image_extensions:
            continue

        image_path = dataset_images + file
        xml_path = dataset_labels + filename + ".xml"
        print("Processing: ", image_path)
        if not os.path.exists(xml_path):
            print("Cannot find the file {} for the image.".format(xml_path))
            continue

        labelName, labelXmin, labelYmin, labelXmax, labelYmax = getLabels(image_path, xml_path)
        orgImage = cv2.imread(image_path)
        if orgImage is None:
            # cv2.imread returns None instead of raising on a bad file.
            print("Cannot read the image {}".format(image_path))
            continue

        boxes = zip(labelName, labelXmin, labelYmin, labelXmax, labelYmax)
        for index, (label, xmin, ymin, xmax, ymax) in enumerate(boxes):
            label_area = orgImage[ymin:ymax, xmin:xmax]
            if label_area.size == 0:
                # A degenerate box yields an empty array that cannot be
                # encoded as an image.
                print("Skipping empty box {} of {}".format(index, image_path))
                continue
            label_img_filename = filename + "_" + str(index) + ".png"
            write_lale_images(label, label_area, extract_to, label_img_filename)
def turn_image():
    """Augment the cropped bean images with a mirror and three rotations.

    For every image in the four class folders beneath ``allpic`` this saves
    four variants (left-right mirror, 90, 180 and 270 degree rotations) into
    the class's own output folder. An input named ``<number>_<rest>`` becomes
    ``<number + 1200000>_<n>_<rest>``, where ``n`` is a running counter offset
    by a multiple of the folder's file count so the four variants of one
    image never share a name. Files whose name does not start with a number
    followed by ``_`` are reported and skipped.
    """
    root = "C:\\Users\\User\\Desktop\\tfcoffebean\\"
    class_names = ("brokenbean", "peaberrybean", "shellbean", "wormbean")

    counter = 0  # deliberately not reset between classes
    for class_name in class_names:
        source_dir = root + "allpic\\" + class_name + "\\"
        save_dir = root + class_name + "\\"
        os.makedirs(save_dir, exist_ok=True)

        listing = os.listdir(source_dir)
        total = len(listing)
        for filename in listing:
            counter += 1
            split_at = filename.find('_')
            try:
                if split_at < 0:
                    raise ValueError(filename)
                new_prefix = str(int(filename[:split_at]) + 1200000)
            except ValueError:
                print("Skipping {}: expected <number>_<rest>".format(filename))
                continue
            rest = filename[split_at:]

            with Image.open(os.path.join(source_dir, filename)) as img:
                variants = (
                    img.transpose(Image.FLIP_LEFT_RIGHT),
                    # expand=True grows the canvas so a non-square crop is
                    # rotated whole instead of being clipped and padded.
                    img.rotate(90, expand=True),
                    img.rotate(180),
                    img.rotate(270, expand=True),
                )
                for step, variant in enumerate(variants):
                    number = counter + step * total
                    variant.save(save_dir + new_prefix + '_' + str(number) + rest)

        print(save_dir)
        print(total)
    print('done')
def changename():
    """Resize the augmented images and collect them in one data folder.

    Every image in the four class folders is resized to 150x150 pixels and
    written to the ``data`` folder as ``<class index>_<n>.png``, where the
    class index is 0-3 in the order brokenbean, peaberrybean, shellbean,
    wormbean and ``n`` counts from 1 within each class. Files that
    ``cv2.imread`` cannot decode are reported and skipped. The elapsed time
    is printed at the end.
    """
    start = datetime.datetime.now()
    root = 'C:\\Users\\User\\Desktop\\tfcoffebean\\'
    class_dirs = (
        root + 'brokenbean\\',
        root + 'peaberrybean\\',
        root + 'shellbean\\',
        root + 'wormbean\\',
    )
    save_path = root + 'data\\'
    image_size = 150
    os.makedirs(save_path, exist_ok=True)

    for class_index, class_dir in enumerate(class_dirs):
        number = 1
        for item in os.listdir(class_dir):
            image_source = cv2.imread(class_dir + item)
            if image_source is None:
                # cv2.imread returns None instead of raising on a bad file.
                print("Cannot read the image {}".format(class_dir + item))
                continue
            # Pass the interpolation flag by keyword: the positional slots
            # after dsize are dst, fx and fy, not interpolation.
            image = cv2.resize(image_source, (image_size, image_size),
                               interpolation=cv2.INTER_LINEAR)
            cv2.imwrite(save_path + str(class_index) + '_' + str(number) + ".png", image)
            number += 1

    end = datetime.datetime.now()
    print("更名執行時間:", end - start)
def cnn(data_dir="data", train=True, model_path="model/image_model"):
    """Train or evaluate a small convolutional classifier on the bean images.

    The network is two convolution + max-pooling stages, a 400-unit dense
    layer with dropout, and a 4-way output layer. All images in *data_dir*
    are fed as a single batch.

    Args:
        data_dir: folder of images named ``<label>_<anything>``; the integer
            before the first underscore is the class label.
        train: when true, run 500 optimisation steps and save the model;
            otherwise restore the model and print a prediction per file.
        model_path: checkpoint path prefix passed to ``tf.train.Saver``.
    """

    def read_data(data_dir):
        """Load every image in *data_dir*; return paths, pixels and labels."""
        datas = []
        labels = []
        fpaths = []
        for fname in os.listdir(data_dir):
            fpath = os.path.join(data_dir, fname)
            fpaths.append(fpath)
            with Image.open(fpath) as image:
                # Force three channels so every sample matches the
                # [150, 150, 3] placeholder, then scale by 1/255.
                data = np.array(image.convert("RGB")) / 255.0
            datas.append(data)
            labels.append(int(fname.split("_")[0]))

        datas = np.array(datas)
        labels = np.array(labels)
        print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
        return fpaths, datas, labels

    fpaths, datas, labels = read_data(data_dir)

    num_classes = 4

    datas_placeholder = tf.placeholder(tf.float32, [None, 150, 150, 3])
    labels_placeholder = tf.placeholder(tf.int32, [None])
    # Dropout rate: 0.25 is fed while training, 0 while testing.
    dropout_placeholdr = tf.placeholder(tf.float32)

    conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
    pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])
    conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])

    flatten = tf.layers.flatten(pool1)
    fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
    # tf.layers.dropout defaults to training=False, which returns its input
    # untouched; enable it and let the fed rate (0 at test time) turn the
    # effect off.
    dropout_fc = tf.layers.dropout(fc, dropout_placeholdr, training=True)
    logits = tf.layers.dense(dropout_fc, num_classes)
    predicted_labels = tf.argmax(logits, 1)

    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(labels_placeholder, num_classes),
        logits=logits
    )
    mean_loss = tf.reduce_mean(losses)
    # Optimise the scalar mean so the step does not scale with batch size.
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(mean_loss)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        if train:
            start = datetime.datetime.now()
            print("train mode")
            sess.run(tf.global_variables_initializer())
            train_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0.25
            }
            for step in range(500):
                _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)
                if step % 5 == 0:
                    print("step = {}\tmean loss = {}".format(step, mean_loss_val))

            # Create the checkpoint folder up front so saving cannot fail on
            # a missing parent directory.
            model_dir = os.path.dirname(model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            saver.save(sess, model_path)
            print("train done,save model{}".format(model_path))
            end = datetime.datetime.now()
            print("訓練執行時間:", end - start)
        else:
            print("test mode")
            saver.restore(sess, model_path)
            print("{}reload model".format(model_path))
            label_name_dict = {
                0: "Brokenbeans",
                1: "Peaberry",
                2: "shellbean",
                3: "Worms"
            }
            test_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0
            }
            predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)
            for fpath, predicted_label in zip(fpaths, predicted_labels_val):
                predicted_label_name = label_name_dict[predicted_label]
                print("{}\t => {}".format(fpath, predicted_label_name))

    # NOTE(review): the source's indentation was lost; this count is assumed
    # to be printed in both modes - confirm.
    print(len(os.listdir(data_dir)))
if __name__ == '__main__':
    # The first two stages are timed; each prints its elapsed time under its
    # own caption as soon as it finishes.
    timed_stages = (
        (extract_lab_to_imgs, "xml轉換img執行時間:"),
        (turn_image, "轉動圖片執行時間:"),
    )
    for stage, caption in timed_stages:
        began = datetime.datetime.now()
        stage()
        finished = datetime.datetime.now()
        print(caption, finished - began)

    # changename() and cnn() report their own timings.
    changename()
    cnn()
|