Browse Source

上傳檔案到 ''

fatwolf 3 years ago
parent
commit
40506340f9
4 changed files with 657 additions and 0 deletions
  1. 40 0
      Coffee_convert_to_csv.py
  2. 281 0
      Coffee_project_detection.py
  3. 316 0
      Coffee_project_train.py
  4. 20 0
      Coffee_resize.py

+ 40 - 0
Coffee_convert_to_csv.py

@@ -0,0 +1,40 @@
+
+import os
+import glob
+import pandas as pd
+import xml.etree.ElementTree as ET
+
+
def xml_to_csv(path):
    """Collect every Pascal-VOC annotation under *path* into one DataFrame.

    Parses each ``*.xml`` file directly inside *path* and emits one row per
    ``<object>`` element.

    Args:
        path: Directory containing Pascal-VOC style XML label files.

    Returns:
        pandas.DataFrame with columns
        ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'].
    """
    xml_list = []
    for xml_file in glob.glob(path + '/*.xml'):
        root = ET.parse(xml_file).getroot()
        for member in root.findall('object'):
            # Pascal-VOC layout: size = (width, height, depth);
            # object = (name, pose, truncated, difficult, bndbox);
            # bndbox = (xmin, ymin, xmax, ymax).
            size = root.find('size')
            bndbox = member[4]
            xml_list.append((
                root.find('filename').text,
                int(size[0].text),
                int(size[1].text),
                member[0].text,
                int(bndbox[0].text),
                int(bndbox[1].text),
                int(bndbox[2].text),
                int(bndbox[3].text),
            ))
    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    xml_df = pd.DataFrame(xml_list, columns=column_name)
    return xml_df
+
+
def main():
    """Convert the XML labels of each listed directory to a labels CSV.

    NOTE(review): both the input folder layout and the output CSV path are
    hard-coded for the original author's machine; adjust before reuse.
    """
    for directory in ['allpic']:
        print(directory)
        image_path = os.path.join(os.getcwd(), 'tfcoffebean/{}'.format(directory))
        xml_df = xml_to_csv(image_path)
        print(image_path)
        xml_df.to_csv('C:/Users/User/Desktop/tfcoffebean/xml_to_csv/{}_labels.csv'.format(directory), index=None)
        print('Successfully converted xml to csv.')


# BUG FIX: main() used to run on plain import; guard it so the module can
# be imported (e.g. to reuse xml_to_csv) without side effects.
if __name__ == '__main__':
    main()

+ 281 - 0
Coffee_project_detection.py

@@ -0,0 +1,281 @@
+#coding=utf-8
+
+import os
+import numpy as np
+import datetime
+import cv2
+import pymysql
+import time
+import tensorflow as tf
+import requests as req
+
+from numba import jit
+from urllib import parse
+from PIL import Image
+
+
# Shared MySQL connection used by cut_rectangle() to read ROI coordinates.
# NOTE(review): credentials are hard-coded in source — move to config/env.
conn = pymysql.connect(host="127.0.0.1", port=3306, user='root', passwd='g53743001', db='coffee_detection',
                               charset='utf8')
# Side length in pixels every ROI crop is resized to before classification.
image_size = 150

# First attached camera, opened once at import time; used by takephoto().
cap = cv2.VideoCapture(0)
def takephoto():
    """Capture one frame from the global camera, upscale it, and save it.

    Writes the frame to the fixed path that cut_rectangle() reads back.
    """
    # NOTE(review): ret is not checked — a failed grab passes None to resize.
    ret, frame = cap.read()
    frame = cv2.resize(frame, (1936, 1096))
    cv2.imwrite("D:\\fatwolf\\company_files\\opencv\\2.png", frame)
    #cap.release()
+
def cut_rectangle():
    """Crop the four configured ROIs out of the latest camera frame.

    ROI coordinates live in the MySQL table `cut` (one row per name
    'roi1'..'roi4'; columns Name, X, X1, Y, Y1).  Each crop is resized to
    image_size x image_size and written to pic/00_<n>.png, which cnn()
    later classifies.  A red rectangle is also drawn on the (local) frame
    copy for debugging.
    """
    img = cv2.imread("D:\\fatwolf\\company_files\\opencv\\2.png")
    point_color = (0, 0, 255)  # red debug outline

    def _fetch_roi(name):
        # One row per ROI name.  Parameterized query replaces the four
        # copy-pasted string-literal SELECTs of the original.
        cursor = conn.cursor()
        cursor.execute("SELECT Name,X, X1 ,Y ,Y1 FROM `cut` WHERE Name LIKE %s", (name,))
        conn.commit()
        return cursor.fetchone()

    def _save_roi(row, index):
        # row layout: (Name, X, X1, Y, Y1); crop is img[y:y1, x:x1].
        x, x1, y, y1 = row[1], row[2], row[3], row[4]
        roi = img[y:y1, x:x1]
        cv2.rectangle(img, (x, y), (x1, y1), point_color, 1)
        roi = cv2.resize(roi, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
        cv2.imwrite('D:\\fatwolf\\company_files\\paper_coffee\\pic\\' + '00_{}.png'.format(index), roi)

    # Fetch all rows first so the timed section matches the original
    # (which timed only the cropping, not the SQL).
    rows = [_fetch_roi(name) for name in ('roi1', 'roi2', 'roi3', 'roi4')]

    start = datetime.datetime.now()
    for index, row in enumerate(rows, start=1):
        _save_roi(row, index)
    end = datetime.datetime.now()
    print("cut_rectangle Run Time:", end - start)
+
def cnn():
    """Classify the four ROI crops written by cut_rectangle().

    Rebuilds the small 2-conv / 1-fc TensorFlow 1.x graph, restores the
    checkpoint saved by the training script, runs one inference pass over
    every image in data_dir and prints "<path> => <predicted class>".

    NOTE(review): mixes `tf.compat.v1.*` and plain TF1 `tf.*` APIs and uses
    the removed `tf.layers` module — runs only on TensorFlow 1.x.
    """
    # data file: folder holding the ROI crops (00_1.png .. 00_4.png).
    data_dir = "D:\\fatwolf\\company_files\\paper_coffee\\pic\\"

    # train or test: this copy of the graph is inference-only.
    train = False
    # model address: checkpoint produced by Coffee_project_train.py.
    model_path = "model/image_model"

    def read_data(data_dir):
        """Load every image in data_dir; label is the file-name prefix before '_'."""
        datas = []
        labels = []
        fpaths = []
        for fname in os.listdir(data_dir):
            fpath = os.path.join(data_dir, fname)
            fpaths.append(fpath)
            image = Image.open(fpath)
            data = np.array(image) / 255.0  # scale pixel values to [0, 1]
            label = str(fname.split("_")[0])


            datas.append(data)
            labels.append(label)

        datas = np.array(datas)
        labels = np.array(labels)

        #print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
        return fpaths, datas, labels

    fpaths, datas, labels = read_data(data_dir)

    # num_classes = len(set(labels))
    num_classes = 4  # fixed set: Brokenbeans / Peaberry / shellbean / Worms

    # Input: batches of 150x150 RGB images (matches image_size crops).
    datas_placeholder = tf.compat.v1.placeholder(tf.float32, [None, 150, 150, 3])

    labels_placeholder = tf.placeholder(tf.int32, [None])

    # Dropout rate; fed as 0 at inference time.
    dropout_placeholdr = tf.placeholder(tf.float32)

    # Graph: conv(20,5x5) -> maxpool -> conv(40,4x4) -> maxpool -> fc(400) -> logits.
    conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
    pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])

    conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])

    flatten = tf.layers.flatten(pool1)

    fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)

    dropout_fc = tf.layers.dropout(fc, dropout_placeholdr)

    logits = tf.layers.dense(dropout_fc, num_classes)

    predicted_labels = tf.arg_max(logits, 1)

    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(labels_placeholder, num_classes),
        logits=logits
    )
    mean_loss = tf.reduce_mean(losses)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-2).minimize(losses)

    saver = tf.compat.v1.train.Saver()

    with tf.compat.v1.Session() as sess:

        if train:
            print("train mode")

            sess.run(tf.global_variables_initializer())

            train_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0.25
            }
            for step in range(500):
                _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)

                if step % 50 == 0:
                    print("step = {}\tmean loss = {}".format(step, mean_loss_val))
            saver.save(sess, model_path)
            print("train done save model{}".format(model_path))
        else:
            #start = datetime.datetime.now()
            #print("reloading model")
            saver.restore(sess, model_path)
            #print("{}reload model".format(model_path))

            # Index -> human-readable class name (order fixed by training labels).
            label_name_dict = {
                0: "Brokenbeans",
                1: "Peaberry",
                2: "shellbean",
                3: "Worms"
            }
            test_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0
            }
            predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)

            for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
                # real_label_name = label_name_dict[real_label]
                predicted_label_name = label_name_dict[predicted_label]
                print("{}\t => {}".format(fpath, predicted_label_name))
                #print(fpath, predicted_label_name)
            dirListing = os.listdir(data_dir)
            #print(len(dirListing))

            #end = datetime.datetime.now()
            #print("elapsed time:", end - start)
+
x = 1  # counter for the (disabled) IFTTT notification block below

# BUG FIX: the endless capture/classify loop used to run on plain import;
# guard it so the module can be imported without starting the loop.
if __name__ == '__main__':
    while True:
        start = datetime.datetime.now()
        #takephoto()
        cut_rectangle()
        cnn()

        '''
        evt = 'notify_me'  # IFTTT event name
        key = 'c3xo5EvpBX64fPEqxphcR4jBTzDh1r2joTDsB_BslOA'
        val1 = parse.quote('執行第')  # value1
        val2 = parse.quote(str(x))  # value2
        val3 = parse.quote('次')  # value3

        url = (f'https://maker.ifttt.com/trigger/{evt}' +
               f'/with/key/{key}?value1={val1}&value2={val2}&value3={val3}')

        r = req.get(url)  # fire the IFTTT webhook
        r.text  # read the IFTTT response
        x = x + 1
        '''
        end = datetime.datetime.now()
        print("完整執行時間:", end - start)
        print('-----------------------------------------------------------')
        #time.sleep(1)
        # Drop the per-iteration TF1 graph so it does not grow without bound.
        tf.reset_default_graph()

+ 316 - 0
Coffee_project_train.py

@@ -0,0 +1,316 @@
+#coding=utf-8
+
+import os
+import numpy as np
+import tensorflow as tf
+import datetime
+import cv2
+import imutils
+import time
+import os.path
+
+from xml.dom import minidom
+from PIL import Image
+
# -------------------------------------------
# NOTE(review): extract_to, dataset_images and dataset_labels all point at
# the same folder — source images, XML labels and extracted crops coexist.
extract_to = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
dataset_images = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
dataset_labels = "C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\"
folderCharacter = "/"  # \\ is for windows
# Bookkeeping text files created by chkEnv() when missing.
xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_file.txt"
object_xml_file = "C:\\Users\\User\\Desktop\\tfcoffebean\\xml_object.txt"
+
def chkEnv():
    """Ensure every folder and bookkeeping file the pipeline needs exists.

    Creates the extraction folder and both XML bookkeeping files when
    absent; terminates the program if a dataset folder is missing.
    """
    if not os.path.exists(extract_to):
        os.makedirs(extract_to)
        print("no {} folder, created.".format(extract_to))

    # Both dataset folders are mandatory — bail out on the first missing one.
    for required in (dataset_images, dataset_labels):
        if not os.path.exists(required):
            print("There is no such folder {}".format(required))
            quit()

    if not os.path.exists(xml_file):
        open(xml_file, 'w').close()
        print("There is no xml file in {} ,created.".format(xml_file))

    if not os.path.exists(object_xml_file):
        open(object_xml_file, 'w').close()
        print("There is no object xml file in {} ,created.".format(object_xml_file))
+
def getLabels(imgFile, xmlFile):
    """Parse a Pascal-VOC XML label file into parallel coordinate lists.

    Args:
        imgFile: Path of the labelled image (unused; kept for API compatibility).
        xmlFile: Path of the Pascal-VOC XML annotation file.

    Returns:
        Tuple (names, xmins, ymins, xmaxs, ymaxs): names is a list of class
        strings, the rest are lists of ints, all in document order.
    """
    labelXML = minidom.parse(xmlFile)

    def _tag_values(tag, cast):
        # One pass per tag name; replaces five copy-pasted loops.
        return [cast(elem.firstChild.data) for elem in labelXML.getElementsByTagName(tag)]

    labelName = _tag_values("name", str)
    labelXmin = _tag_values("xmin", int)
    labelYmin = _tag_values("ymin", int)
    labelXmax = _tag_values("xmax", int)
    labelYmax = _tag_values("ymax", int)

    return labelName, labelXmin, labelYmin, labelXmax, labelYmax
+
def write_lale_images(label, img, saveto, filename):
    """Save one cropped label image under a per-class sub-folder of extract_to.

    Args:
        label: Class name; becomes the sub-folder name.
        img: Image array (as produced by cv2 slicing) to write.
        saveto: Unused; kept for API compatibility.
        filename: Output file name for the crop.
    """
    writePath = extract_to + label
    print("WRITE:", writePath)
    if not os.path.exists(writePath):
        os.makedirs(writePath)

    # BUG FIX: the write used to sit inside the "if" above, so crops were
    # only saved the very first time a class folder was created — every
    # later crop for an existing class was silently dropped.
    cv2.imwrite(writePath + folderCharacter + filename, img)
+
def extract_lab_to_imgs():
    """Cut every labelled object out of the dataset images into per-class folders.

    For each image in dataset_images with a matching .xml annotation, reads
    the bounding boxes via getLabels() and writes each crop through
    write_lale_images(), using the label name as the sub-folder.
    """
    chkEnv()
    i = 0  # NOTE(review): never used after initialization
    for file in os.listdir(dataset_images):
        filename, file_extension = os.path.splitext(file)
        file_extension = file_extension.lower()
        if (
                file_extension == ".jpg" or file_extension == ".jpeg" or file_extension == ".png" or file_extension == ".bmp"):
            print("Processing: ", dataset_images + file)
            if not os.path.exists(dataset_labels + filename + ".xml"):
                print("Cannot find the file {} for the image.".format(dataset_labels + filename + ".xml"))
            else:
                image_path = dataset_images + file
                xml_path = dataset_labels + filename + ".xml"
                labelName, labelXmin, labelYmin, labelXmax, labelYmax = getLabels(image_path, xml_path)
                orgImage = cv2.imread(image_path)
                image = orgImage.copy()
                for id, label in enumerate(labelName):
                    # Rectangle goes on the preview copy; the crop is taken
                    # from the untouched original.
                    cv2.rectangle(image, (labelXmin[id], labelYmin[id]), (labelXmax[id], labelYmax[id]), (0, 255, 0), 2)
                    label_area = orgImage[labelYmin[id]:labelYmax[id], labelXmin[id]:labelXmax[id]]
                    label_img_filename = filename + "_" + str(id) + ".png"
                    write_lale_images(labelName[id], label_area, extract_to, label_img_filename)
                #cv2.imshow("Image", imutils.resize(image, width=700))
                #k = cv2.waitKey(1)
+
def turn_image():
    """Augment every crop with a horizontal flip plus 90/180/270 rotations.

    Reads each class folder under allpic and writes four transformed copies
    into the matching top-level class folder, renumbering files so names
    stay unique: original numeric prefix + 1200000, sequence index offset
    by multiples of the folder size per transform.
    """
    dir1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\brokenbean\\'
    save_path1 = "C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\"

    dir2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\peaberrybean\\'
    save_path2 = "C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\"

    # NOTE(review): '\\\s' and '\\\w' leave a stray backslash before the
    # folder name (doubled separator). Windows tolerates it, but these
    # paths should be normalized — confirm they resolve as intended.
    dir3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\shellbean\\'
    save_path3 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\shellbean\\"

    dir4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\allpic\\\wormbean\\'
    save_path4 = "C:\\Users\\User\\Desktop\\tfcoffebean\\\wormbean\\"

    x = [dir1, dir2, dir3, dir4]
    y = [save_path1, save_path2, save_path3, save_path4]

    i = 0   # global file counter — keeps incrementing across folders
    c = -1  # index into the save-path list y
    for dir in x:
        dirListing = os.listdir(dir)

        c = c + 1
        for filename in os.listdir(dir):
            img = Image.open(dir + "/" + filename)
            i = i + 1
            img_rotate0 = img.transpose(Image.FLIP_LEFT_RIGHT)
            img_rotate1 = img.rotate(90)
            img_rotate2 = img.rotate(180)
            img_rotate3 = img.rotate(270)

            # File names look like "<number>_<rest>"; shift the number by
            # 1200000 and re-sequence each transform in its own range.
            temp = filename[0: filename.find('_')]
            temp2 = int(temp) + 1200000
            temp3 = str(temp2) + '_' + str(i) + filename[filename.find('_'):]
            img_rotate0.save(y[c] + temp3)
            temp3 = str(temp2) + '_' + str(i + len(dirListing)) + filename[filename.find('_'):]
            img_rotate1.save(y[c] + temp3)
            temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing)) + filename[filename.find('_'):]
            img_rotate2.save(y[c] + temp3)
            temp3 = str(temp2) + '_' + str(i + len(dirListing) + len(dirListing) + len(dirListing)) + filename[
                                                                                                      filename.find(
                                                                                                          '_'):]
            img_rotate3.save(y[c] + temp3)
        print(y[c])
        print(len(dirListing))
        print('done')
+
def changename():
    """Resize every augmented crop to image_size and renumber it as <class>_<n>.png.

    The class index is the position of the source folder (0=broken,
    1=peaberry, 2=shell, 3=worm).  All output lands in one flat `data`
    folder whose file names encode the label — exactly what cnn()'s
    read_data expects.  Replaces four copy-pasted loops with one.
    """
    start = datetime.datetime.now()
    path_name1 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\brokenbean\\'
    path_name2 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\peaberrybean\\'
    path_name3 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\shellbean\\'
    path_name4 = 'C:\\Users\\User\\Desktop\\tfcoffebean\\wormbean\\'

    save_path = 'C:\\Users\\User\\Desktop\\tfcoffebean\\data\\'
    image_size = 150

    for class_idx, folder in enumerate((path_name1, path_name2, path_name3, path_name4)):
        # Fresh sequential counter per class: 0_1.png, 0_2.png, ... 1_1.png, ...
        for count, item in enumerate(os.listdir(folder), start=1):
            image_source = cv2.imread(folder + item)
            image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
            cv2.imwrite(save_path + (str(class_idx) + '_' + str(count)) + ".png", image)

    end = datetime.datetime.now()
    print("更名執行時間:", end - start)
+
def cnn():
    """Train the coffee-bean classifier on the flat `data` folder.

    Builds a 2-conv / 1-fc TensorFlow 1.x graph over 150x150 RGB crops,
    trains for 500 steps on the whole dataset at once (no batching), and
    saves the checkpoint to model_path.  The `train` flag flips the same
    function into test mode.

    NOTE(review): TF1-only APIs (`tf.placeholder`, `tf.layers`, `tf.arg_max`);
    will not run on TensorFlow 2 without compat shims.
    """
    # data file: flat folder of "<class>_<n>.png" files from changename().
    data_dir = "data"

    # train or test
    train = True
    # model address: checkpoint path shared with the detection script.
    model_path = "model/image_model"

    def read_data(data_dir):
        """Load every image in data_dir; label is the int prefix before '_'."""
        datas = []
        labels = []
        fpaths = []
        for fname in os.listdir(data_dir):
            fpath = os.path.join(data_dir, fname)
            fpaths.append(fpath)
            image = Image.open(fpath)
            data = np.array(image) / 255.0  # scale pixel values to [0, 1]
            label = int(fname.split("_")[0])

            datas.append(data)
            labels.append(label)

        datas = np.array(datas)
        labels = np.array(labels)

        print("shape of datas: {}\tshape of labels: {}".format(datas.shape, labels.shape))
        return fpaths, datas, labels

    fpaths, datas, labels = read_data(data_dir)

    # num_classes = len(set(labels))
    num_classes = 4  # fixed set: Brokenbeans / Peaberry / shellbean / Worms

    # Input: batches of 150x150 RGB images.
    datas_placeholder = tf.placeholder(tf.float32, [None, 150, 150, 3])

    labels_placeholder = tf.placeholder(tf.int32, [None])

    # Dropout rate; 0.25 while training, 0 at test time.
    dropout_placeholdr = tf.placeholder(tf.float32)

    # Graph: conv(20,5x5) -> maxpool -> conv(40,4x4) -> maxpool -> fc(400) -> logits.
    conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
    pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])

    conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])

    flatten = tf.layers.flatten(pool1)

    fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)

    dropout_fc = tf.layers.dropout(fc, dropout_placeholdr)

    logits = tf.layers.dense(dropout_fc, num_classes)

    predicted_labels = tf.arg_max(logits, 1)

    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(labels_placeholder, num_classes),
        logits=logits
    )
    mean_loss = tf.reduce_mean(losses)

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(losses)

    saver = tf.train.Saver()

    with tf.Session() as sess:

        if train:
            start = datetime.datetime.now()
            print("train mode")

            sess.run(tf.global_variables_initializer())

            # Whole dataset in one feed — no mini-batching.
            train_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0.25
            }
            for step in range(500):
                _, mean_loss_val = sess.run([optimizer, mean_loss], feed_dict=train_feed_dict)

                if step % 5 == 0:
                    print("step = {}\tmean loss = {}".format(step, mean_loss_val))
            saver.save(sess, model_path)
            print("train done,save model{}".format(model_path))
            end = datetime.datetime.now()
            print("訓練執行時間:", end - start)
        else:

            print("test mode")
            saver.restore(sess, model_path)
            print("{}reload model".format(model_path))

            # Index -> human-readable class name.
            label_name_dict = {
                0: "Brokenbeans",
                1: "Peaberry",
                2: "shellbean",
                3: "Worms"
            }
            test_feed_dict = {
                datas_placeholder: datas,
                labels_placeholder: labels,
                dropout_placeholdr: 0
            }
            predicted_labels_val = sess.run(predicted_labels, feed_dict=test_feed_dict)

            for fpath, real_label, predicted_label in zip(fpaths, labels, predicted_labels_val):
                # real_label_name = label_name_dict[real_label]
                predicted_label_name = label_name_dict[predicted_label]
                print("{}\t => {}".format(fpath, predicted_label_name))
            dirListing = os.listdir(data_dir)
            print(len(dirListing))
+
+
if __name__ == '__main__':
    # Full pipeline: XML -> per-class crops -> augmentation -> resize/rename -> train.
    start = datetime.datetime.now()
    extract_lab_to_imgs()
    end = datetime.datetime.now()
    print("xml轉換img執行時間:", end - start)
    start1 = datetime.datetime.now()
    turn_image()
    end1 = datetime.datetime.now()
    print("轉動圖片執行時間:", end1 - start1)
    changename()
    cnn()

+ 20 - 0
Coffee_resize.py

@@ -0,0 +1,20 @@
import cv2
import os

# Square size (px) every image is resized to.
image_size = 150
# NOTE(review): source and target are the same folder, so output PNGs land
# next to the inputs; re-running processes previous outputs too.
source_path = "test/"
target_path = "test/"

if not os.path.exists(target_path):
    os.makedirs(target_path)

image_list = os.listdir(source_path)

i = 0
for file in image_list:
    i = i + 1
    image_source = cv2.imread(source_path + file)
    # BUG FIX: cv2.imread returns None for unreadable/non-image files,
    # which used to crash cv2.resize — skip those files instead.
    if image_source is None:
        print("skip (not an image):", source_path + file)
        continue
    image = cv2.resize(image_source, (image_size, image_size), 0, 0, cv2.INTER_LINEAR)
    cv2.imwrite(target_path + ('0' + '_' + str(i)) + ".png", image)
print("done")