Using TensorFlow's slim Models to Train a Model for iOS
Published: 2019-06-22


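1. Download the TensorFlow models repository:

cd /Users/javalong/Downloads
git clone https://github.com/tensorflow/models/

The later steps refer to the repository directory as models-master, which is the name GitHub uses when you download the repository as a ZIP archive. git clone creates a directory named models instead, so adjust the paths to match. In current versions of the repository, the slim code is located under research/slim rather than directly under the top-level directory.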

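2. Prepare the data. You can use a dataset provided by slim or images you have collected yourself.

2.1. Download the flowers dataset provided by slim

2.1.1. Set the download directory: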

DATA_DIR=/Users/javalong/Desktop/Test/output/flowers

2.1.2. Change into the slim directory:

cd /Users/javalong/Downloads/models-master/slim

2.1.3. Download the dataset and convert it to TFRecord format:

python3 download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir="${DATA_DIR}"

2.1.4. List the files in the directory:

ls ${DATA_DIR}

Output:

flowers_train-00000-of-00005.tfrecord
...
flowers_train-00004-of-00005.tfrecord
flowers_validation-00000-of-00005.tfrecord
...
flowers_validation-00004-of-00005.tfrecord
labels.txt
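To confirm that the download is complete, you can count the records in these files without loading TensorFlow. The sketch below relies only on the TFRecord framing: an 8-byte little-endian length, a 4-byte masked CRC32C of the length, the payload (a serialized tf.train.Example), and a 4-byte masked CRC32C of the payload. It skips the CRC checks, so it is a quick sanity check rather than a validator; iter_tfrecord_payloads and count_records are hypothetical helper names.

```python
import struct
from typing import BinaryIO, Iterable, Iterator


def iter_tfrecord_payloads(stream: BinaryIO) -> Iterator[bytes]:
    """Yield the raw payload of each record in a TFRecord stream.

    On-disk framing of every record:
      uint64 length (little-endian)
      uint32 masked CRC32C of the length   (skipped here, not verified)
      length bytes of payload
      uint32 masked CRC32C of the payload  (skipped here, not verified)
    """
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        stream.read(4)  # CRC of the length field
        payload = stream.read(length)
        if len(payload) != length:
            raise ValueError("truncated record payload")
        stream.read(4)  # CRC of the payload
        yield payload


def count_records(paths: Iterable[str]) -> int:
    """Total number of records across several TFRecord files."""
    total = 0
    for path in paths:
        with open(path, "rb") as f:
            total += sum(1 for _ in iter_tfrecord_payloads(f))
    return total
```

For example, count_records(glob.glob("/Users/javalong/Desktop/Test/output/flowers/flowers_train-*.tfrecord")) should agree with the training split size declared in slim's datasets/flowers.py.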

2.2. As the listing shows, the dataset files provided by slim are in TFRecord format, so to train on your own images the first step is to convert them to TFRecord files.

2.2.1. Install the packages needed to convert images to TFRecord files:

pip3 install Pillow

pip3 install matplotlib

2.2.2. Create a file named fu_img_to_tfrecord.py in /Users/javalong/Downloads/models-master/slim.

[Figure: fu_img_to_tfrecord.py in the slim directory]

2.2.3. Contents of fu_img_to_tfrecord.py:

# Written for TensorFlow 1.x (tf.placeholder, tf.Session, tf.python_io).
# Usage: python3 fu_img_to_tfrecord.py <image_dir> <output_dir>
import os
import os.path
import sys

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

data_dir = sys.argv[1]   # one sub-directory of JPEG images per class
train_dir = sys.argv[2]  # output directory for the TFRecord and text files

# Every sub-directory of data_dir is a class; sorted so label ids are stable.
classes = []
for dir_name in sorted(os.listdir(data_dir)):
    path = os.path.join(data_dir, dir_name)
    if os.path.isdir(path):
        classes.append(dir_name)

train = tf.python_io.TFRecordWriter(os.path.join(train_dir, "iss_train.tfrecord"))
test = tf.python_io.TFRecordWriter(os.path.join(train_dir, "iss_test.tfrecord"))


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, image_format, height, width, class_id):
    # Same feature keys as slim's datasets/dataset_utils.py.
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))


def get_extension(path):
    return os.path.splitext(path)[1]


class ImageReader(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self):
        # Initializes function that decodes RGB JPEG data.
        self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
        self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

    def read_image_dims(self, sess, image_data):
        image = self.decode_jpeg(sess, image_data)
        return image.shape[0], image.shape[1]

    def decode_jpeg(self, sess, image_data):
        image = sess.run(self._decode_jpeg,
                         feed_dict={self._decode_jpeg_data: image_data})
        assert len(image.shape) == 3
        assert image.shape[2] == 3
        return image


def write_label_file(labels_to_class_names, dataset_dir, filename='labels.txt'):
    """Writes a file with the list of class names.

    Args:
      labels_to_class_names: A map of (integer) labels to class names.
      dataset_dir: The directory in which the labels file should be written.
      filename: The filename where the class names are written.
    """
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))


label_file = os.path.join(train_dir, 'label.txt')
label_input = open(label_file, 'w')
info_file = os.path.join(train_dir, 'meta_info.txt')
test_num = 0
train_num = 0

with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess:
        for index, name in enumerate(classes):
            label_input.write('%d:%s\n' % (index, name))
            class_path = os.path.join(data_dir, name)
            for num, img_name in enumerate(os.listdir(class_path)):
                # Skip non-JPEG files (e.g. macOS .DS_Store), which decode_jpeg cannot read.
                if get_extension(img_name).lower() not in ('.jpg', '.jpeg'):
                    continue
                img_path = os.path.join(class_path, img_name)
                with tf.gfile.FastGFile(img_path, 'rb') as f:
                    image_data = f.read()
                height, width = image_reader.read_image_dims(sess, image_data)
                example = image_to_tfexample(image_data, b'jpg', height, width, index)
                # Every 5th file of a class goes to the test set, the rest to training.
                if num % 5 == 0:
                    test_num += 1
                    test.write(example.SerializeToString())
                else:
                    train_num += 1
                    train.write(example.SerializeToString())

train.close()
test.close()
label_input.close()

with open(info_file, 'w') as info_input:
    info_input.write("train_num:" + str(train_num) + '\n')
    info_input.write("test_num:" + str(test_num) + '\n')
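The split in the script is deterministic: within each class, the files at positions 0, 5, 10, ... of the directory listing go to iss_test.tfrecord and the rest to iss_train.tfrecord, so roughly 20% of the images end up in the test set. The TensorFlow-free sketch below reproduces that rule together with the label numbering, which makes it easy to check the counts in meta_info.txt. It assumes class ids follow the sorted class names (as slim's download_and_convert_flowers.py does); split_dataset is a hypothetical name.

```python
from typing import Dict, List, Tuple


def split_dataset(
    files_by_class: Dict[str, List[str]],
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]], List[str]]:
    """Assign class ids and split files into train/test like the conversion script.

    - class ids follow the sorted order of the class (sub-directory) names
    - within a class, the file at position num goes to test when num % 5 == 0
    Returns (train, test, label_lines); train/test hold (file, class_id) pairs.
    """
    train: List[Tuple[str, int]] = []
    test: List[Tuple[str, int]] = []
    label_lines: List[str] = []
    for class_id, name in enumerate(sorted(files_by_class)):
        label_lines.append("%d:%s" % (class_id, name))
        for num, filename in enumerate(files_by_class[name]):
            target = test if num % 5 == 0 else train
            target.append((filename, class_id))
    return train, test, label_lines
```

Because the positions come from os.listdir, which files land in the test set depends on the filesystem's listing order, not on the file names.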

2.2.4. Run the conversion (the first argument is the image directory, the second is the output directory):

python3 /Users/javalong/Downloads/models-master/slim/fu_img_to_tfrecord.py /Users/javalong/Desktop/flowers /Users/javalong/Desktop/flower_record
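The conversion script assumes that every sub-directory of the input directory is a class containing JPEG images (everything is decoded with tf.image.decode_jpeg), and it writes into an output directory that it does not create itself. A small pre-flight check along these lines can catch both problems before TensorFlow is loaded; check_image_dir and ensure_output_dir are hypothetical helper names.

```python
import os
from typing import Dict

JPEG_EXTENSIONS = (".jpg", ".jpeg")


def check_image_dir(data_dir: str) -> Dict[str, int]:
    """Return {class_name: number_of_jpeg_files} for each sub-directory of data_dir.

    Raises ValueError if there are no class directories or a class has no JPEGs.
    """
    counts: Dict[str, int] = {}
    for name in sorted(os.listdir(data_dir)):
        class_path = os.path.join(data_dir, name)
        if not os.path.isdir(class_path):
            continue  # stray files such as .DS_Store are not classes
        jpegs = [
            f for f in os.listdir(class_path)
            if os.path.splitext(f)[1].lower() in JPEG_EXTENSIONS
        ]
        if not jpegs:
            raise ValueError("class directory has no JPEG images: " + class_path)
        counts[name] = len(jpegs)
    if not counts:
        raise ValueError("no class sub-directories found in " + data_dir)
    return counts


def ensure_output_dir(output_dir: str) -> None:
    """Create the output directory (and any parents) if it does not exist yet."""
    os.makedirs(output_dir, exist_ok=True)
```

For this post, check_image_dir("/Users/javalong/Desktop/flowers") lists the classes and image counts, and ensure_output_dir("/Users/javalong/Desktop/flower_record") prepares the output location.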

Notes:

2.2.5. /Users/javalong/Desktop/flowers holds the collected images, with one sub-directory per class:

[Figure: the class sub-directories of /Users/javalong/Desktop/flowers]

2.2.6. /Users/javalong/Desktop/flower_record is the directory where the generated TFRecord files are written. The resulting files:

[Figure: the generated files in /Users/javalong/Desktop/flower_record]

2.2.7. The sub-directory names of /Users/javalong/Desktop/flowers are used as the class names and are written to the generated label.txt, one index:name pair per line:

[Figure: contents of label.txt]

2.2.8. fu_img_to_tfrecord.py is based on /Users/javalong/Downloads/models-master/slim/datasets/download_and_convert_flowers.py.
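Because the label file stores one index:name pair per line, the class index predicted by the model can be mapped back to a readable name, for example in the iOS app. A minimal Python sketch of that lookup, assuming exactly this format (parse_label_file is a hypothetical helper):

```python
from typing import Dict


def parse_label_file(text: str) -> Dict[int, str]:
    """Parse lines of the form '<index>:<class name>' into {index: name}."""
    labels: Dict[int, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        index, name = line.split(":", 1)  # split once so names may contain ':'
        labels[int(index)] = name
    return labels
```

Typical use: with open("/Users/javalong/Desktop/flower_record/label.txt") as f: labels = parse_label_file(f.read()).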

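3. Fine-tune the pre-trained inception_v3 model on the flowers dataset

3.1. Set the directories (CHECKPOINT_PATH points to the pre-trained inception_v3 checkpoint, and TRAIN_DIR is where the new checkpoints are written):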

DATASET_DIR=/Users/javalong/Desktop/Test/output/flowers

CHECKPOINT_PATH=/Users/javalong/Desktop/Test/output/inception/inception_v3.ckpt

TRAIN_DIR=/Users/javalong/Desktop/Test/output/tran
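Step 3 assumes that inception_v3.ckpt is already in /Users/javalong/Desktop/Test/output/inception. One way to get it is the Inception V3 archive from slim's list of pre-trained models. The URL below is assumed to be the one listed there, so verify it against the slim README before relying on it; download_checkpoint and extract_checkpoint are hypothetical helper names.

```python
import os
import tarfile
import urllib.request

# Assumption: URL of the Inception V3 archive in slim's pre-trained model list.
INCEPTION_V3_URL = "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"


def extract_checkpoint(tar_path: str, dest_dir: str) -> str:
    """Extract a .tar.gz checkpoint archive and return the path of the .ckpt file."""
    os.makedirs(dest_dir, exist_ok=True)
    with tarfile.open(tar_path, "r:gz") as tar:
        names = tar.getnames()
        tar.extractall(dest_dir)
    ckpts = [n for n in names if n.endswith(".ckpt")]
    if not ckpts:
        raise ValueError("no .ckpt file found in " + tar_path)
    return os.path.join(dest_dir, ckpts[0])


def download_checkpoint(dest_dir: str, url: str = INCEPTION_V3_URL) -> str:
    """Download the archive into dest_dir (unless already present) and extract it."""
    os.makedirs(dest_dir, exist_ok=True)
    tar_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(tar_path):
        urllib.request.urlretrieve(url, tar_path)
    return extract_checkpoint(tar_path, dest_dir)
```

Calling download_checkpoint("/Users/javalong/Desktop/Test/output/inception") should return a path ending in inception_v3.ckpt, matching CHECKPOINT_PATH, assuming the archive contains that single file.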

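3.2. Run the training. Only the final layers (InceptionV3/Logits and InceptionV3/AuxLogits) are excluded from the checkpoint and retrained; all other weights are loaded from inception_v3.ckpt and left unchanged:

python3 train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --model_name=inception_v3 \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --clone_on_cpu=true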
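The two scope flags control different things. --checkpoint_exclude_scopes decides which variables are not restored from inception_v3.ckpt (the logits layers, whose shapes differ because the checkpoint was trained on 1001 ImageNet classes while flowers has 5), and --trainable_scopes decides which variables the optimizer updates. The sketch below imitates that prefix-based selection on plain variable-name strings; it is a simplification for illustration, not the code in train_image_classifier.py, and the variable names in the test are hypothetical.

```python
from typing import List, Tuple


def split_scopes(flag_value: str) -> List[str]:
    """Turn a comma-separated flag value into a list of scope prefixes."""
    return [s.strip() for s in flag_value.split(",") if s.strip()]


def plan_fine_tuning(
    variable_names: List[str], exclude_scopes: str, trainable_scopes: str
) -> Tuple[List[str], List[str]]:
    """Return (restored_from_checkpoint, trained) variable names.

    A variable is restored unless its name starts with an excluded scope,
    and it is trained only if its name starts with one of the trainable scopes.
    """
    excluded = split_scopes(exclude_scopes)
    trainable = split_scopes(trainable_scopes)
    restored = [v for v in variable_names if not any(v.startswith(s) for s in excluded)]
    trained = [v for v in variable_names if any(v.startswith(s) for s in trainable)]
    return restored, trained
```

With the same value for both flags, as in the command above, the restored set and the trained set are complementary: every variable is either loaded from the checkpoint and frozen, or freshly initialized and trained.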

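4. Export a frozen .pb file

4.1. Create a file named bbb.py in /Users/javalong/Downloads/models-master/slim (the script imports slim's nets package, so it needs to run from this directory).

[Figure: bbb.py in the slim directory]

4.2. Contents of bbb.py (it restores the latest checkpoint from TRAIN_DIR and writes the frozen graph to inception_v3_final.pb):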

# Written for TensorFlow 1.x; run from the slim directory so that "nets" can be imported.
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import inception
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from google.protobuf import text_format

# Latest checkpoint written by train_image_classifier.py (TRAIN_DIR in step 3).
checkpoint_path = tf.train.latest_checkpoint('/Users/javalong/Desktop/Test/output/tran')

with tf.Graph().as_default() as graph:
    # Input node: a batch of 299x299 RGB images, the input size of inception_v3.
    input_tensor = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input_image')
    with slim.arg_scope(inception.inception_v3_arg_scope()):
        # num_classes must match the training dataset (5 for flowers).
        logits, end_points = inception.inception_v3(input_tensor, num_classes=5, is_training=False)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Restore the trained weights while the session is still open.
        saver.restore(sess, checkpoint_path)

        # Softmax output node of slim's inception_v3.
        output_node_names = 'InceptionV3/Predictions/Reshape_1'
        input_graph_def = graph.as_graph_def()
        # Replace variables with constants so the graph can be loaded without a checkpoint.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def, output_node_names.split(","))

    # Human-readable dump of the frozen graph, useful for looking up node names.
    with open('/Users/javalong/Desktop/Test/output/output_graph_nodes.txt', 'w') as f:
        f.write(text_format.MessageToString(output_graph_def))

    output_graph = '/Users/javalong/Desktop/Test/output/inception_v3_final.pb'
    with gfile.FastGFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
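The frozen graph takes float32 input of shape [N, 299, 299, 3] at input_image and produces softmax probabilities at InceptionV3/Predictions/Reshape_1. Assuming training used slim's default preprocessing for inception_v3 (inception_preprocessing, which scales pixel values to the range [-1, 1]), the app needs to apply the same scaling to its input and then map the highest probabilities back to names from the label file. The sketch below shows both steps in plain Python; scale_pixels and top_k are hypothetical helper names, and resizing or cropping to 299x299 is left out.

```python
from typing import Dict, List, Sequence, Tuple


def scale_pixels(pixels: Sequence[int]) -> List[float]:
    """Map 8-bit channel values 0..255 to [-1, 1], as Inception preprocessing does."""
    return [(p / 255.0 - 0.5) * 2.0 for p in pixels]


def top_k(
    probabilities: Sequence[float], labels: Dict[int, str], k: int = 3
) -> List[Tuple[str, float]]:
    """Return the k most likely (class name, probability) pairs, best first."""
    order = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(labels[i], probabilities[i]) for i in order[:k]]
```

The same two steps would be written in Swift or Objective-C in the iOS app; the Python version only pins down the arithmetic.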

5. Optimize the model and remove the operators that iOS does not support

Reposted from: http://rymia.baihongyu.com/
