服务器之家

服务器之家 > 正文

Tensorflow之构建自己的图片数据集TFrecords的方法

时间:2021-01-13 00:05     来源/作者:tengxing007

学习谷歌的深度学习终于有点眉目了,给大家分享我的Tensorflow学习历程。

tensorflow的官方中文文档比较生涩,数据集一直采用的MNIST二进制数据集。并没有过多讲述怎么构建自己的图片数据集tfrecords。

流程是:制作数据集—读取数据集—-加入队列

先贴完整的代码:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image
 
cwd = os.getcwd()
 
classes = {'test','test1','test2'}
#制作二进制数据
def create_record():
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((64, 64))
      img_raw = img.tobytes() #将图片转化为原生bytes
      print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(feature={
          "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
          'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        }))
      writer.write(example.SerializeToString())
  writer.close()
 
data = create_record()
 
#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label
 
if __name__ == '__main__':
  if 0:
    data = create_record("train.tfrecords")
  else:
    img, label = read_and_decode("train.tfrecords")
    print "tengxing",img,label
    #使用shuffle_batch可以随机打乱输入 next_batch挨着往下取
    # shuffle_batch才能实现[img,label]的同步,也即特征和label的同步,不然可能输入的特征和label不匹配
    # 比如只有这样使用,才能使img和label一一对应,每次提取一个image和对应的label
    # shuffle_batch返回的值就是RandomShuffleQueue.dequeue_many()的结果
    # Shuffle_batch构建了一个RandomShuffleQueue,并不断地把单个的[img,label],送入队列中
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                          batch_size=4, capacity=2000,
                          min_after_dequeue=1000)
 
    # 初始化所有的op
    init = tf.initialize_all_variables()
 
    with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

制作数据集

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#制作二进制数据
def create_record():
  cwd = os.getcwd()
  classes = {'1','2','3'}
  writer = tf.python_io.TFRecordWriter("train.tfrecords")
  for index, name in enumerate(classes):
    class_path = cwd +"/"+ name+"/"
    for img_name in os.listdir(class_path):
      img_path = class_path + img_name
      img = Image.open(img_path)
      img = img.resize((28, 28))
      img_raw = img.tobytes() #将图片转化为原生bytes
      #print index,img_raw
      example = tf.train.Example(
        features=tf.train.Features(
          feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
          }
        )
      )
      writer.write(example.SerializeToString())
  writer.close()

TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

读取数据集

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#读取二进制数据
def read_and_decode(filename):
  # 创建文件队列,不限读取的数量
  filename_queue = tf.train.string_input_producer([filename])
  # create a reader from file queue
  reader = tf.TFRecordReader()
  # reader从文件队列中读入一个序列化的样本
  _, serialized_example = reader.read(filename_queue)
  # get feature from serialized example
  # 解析符号化的样本
  features = tf.parse_single_example(
    serialized_example,
    features={
      'label': tf.FixedLenFeature([], tf.int64),
      'img_raw': tf.FixedLenFeature([], tf.string)
    }
  )
  label = features['label']
  img = features['img_raw']
  img = tf.decode_raw(img, tf.uint8)
  img = tf.reshape(img, [64, 64, 3])
  img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
  label = tf.cast(label, tf.int32)
  return img, label

一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List

加入队列

?
1
2
3
4
5
6
7
8
9
with tf.Session() as sess:
      sess.run(init)
      # 启动队列
      threads = tf.train.start_queue_runners(sess=sess)
      for i in range(5):
        print img_batch.shape,label_batch
        val, l = sess.run([img_batch, label_batch])
        # l = to_categorical(l, 12)
        print(val.shape, l)

这样就可以的到和tensorflow官方的二进制数据集了,

注意:

  1. 启动队列那条code不要忘记,不然卡死
  2. 使用的时候记得使用val和l,不然会报类型错误:TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
  3. 算交叉熵时候:cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,labels)算交叉熵
  4. 最后评估的时候用tf.nn.in_top_k(logits,labels,1)选logits最大的数的索引和label比较
  5. cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))算交叉熵,所以label必须转成one-hot向量

实例2:将图片文件夹下的图片转存tfrecords的数据集。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
############################################################################################
#!/usr/bin/python2.7
# -*- coding: utf-8 -*-
#Author : zhaoqinghui
#Date  : 2016.5.10
#Function: image convert to tfrecords 
#############################################################################################
 
import tensorflow as tf
import numpy as np
import cv2
import os
import os.path
from PIL import Image
 
#参数设置
###############################################################################################
train_file = 'train.txt' #训练图片
name='train'   #生成train.tfrecords
output_directory='./tfrecords'
resize_height=32 #存储图片高度
resize_width=32 #存储图片宽度
###############################################################################################
def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
def load_file(examples_list_file):
  lines = np.genfromtxt(examples_list_file, delimiter=" ", dtype=[('col1', 'S120'), ('col2', 'i8')])
  examples = []
  labels = []
  for example, label in lines:
    examples.append(example)
    labels.append(label)
  return np.asarray(examples), np.asarray(labels), len(lines)
 
def extract_image(filename, resize_height, resize_width):
  image = cv2.imread(filename)
  image = cv2.resize(image, (resize_height, resize_width))
  b,g,r = cv2.split(image)    
  rgb_image = cv2.merge([r,g,b])   
  return rgb_image
 
def transform2tfrecord(train_file, name, output_directory, resize_height, resize_width):
  if not os.path.exists(output_directory) or os.path.isfile(output_directory):
    os.makedirs(output_directory)
  _examples, _labels, examples_num = load_file(train_file)
  filename = output_directory + "/" + name + '.tfrecords'
  writer = tf.python_io.TFRecordWriter(filename)
  for i, [example, label] in enumerate(zip(_examples, _labels)):
    print('No.%d' % (i))
    image = extract_image(example, resize_height, resize_width)
    print('shape: %d, %d, %d, label: %d' % (image.shape[0], image.shape[1], image.shape[2], label))
    image_raw = image.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
      'image_raw': _bytes_feature(image_raw),
      'height': _int64_feature(image.shape[0]),
      'width': _int64_feature(image.shape[1]),
      'depth': _int64_feature(image.shape[2]),
      'label': _int64_feature(label)
    }))
    writer.write(example.SerializeToString())
  writer.close()
 
def disp_tfrecords(tfrecord_list_file):
  filename_queue = tf.train.string_input_producer([tfrecord_list_file])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
    serialized_example,
 features={
     'image_raw': tf.FixedLenFeature([], tf.string),
     'height': tf.FixedLenFeature([], tf.int64),
     'width': tf.FixedLenFeature([], tf.int64),
     'depth': tf.FixedLenFeature([], tf.int64),
     'label': tf.FixedLenFeature([], tf.int64)
   }
  )
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  #print(repr(image))
  height = features['height']
  width = features['width']
  depth = features['depth']
  label = tf.cast(features['label'], tf.int32)
  init_op = tf.initialize_all_variables()
  resultImg=[]
  resultLabel=[]
  with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(21):
      image_eval = image.eval()
      resultLabel.append(label.eval())
      image_eval_reshape = image_eval.reshape([height.eval(), width.eval(), depth.eval()])
      resultImg.append(image_eval_reshape)
      pilimg = Image.fromarray(np.asarray(image_eval_reshape))
      pilimg.show()
    coord.request_stop()
    coord.join(threads)
    sess.close()
  return resultImg,resultLabel
 
def read_tfrecord(filename_queuetemp):
  filename_queue = tf.train.string_input_producer([filename_queuetemp])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
    serialized_example,
    features={
     'image_raw': tf.FixedLenFeature([], tf.string),
     'width': tf.FixedLenFeature([], tf.int64),
     'depth': tf.FixedLenFeature([], tf.int64),
     'label': tf.FixedLenFeature([], tf.int64)
   }
  )
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  # image
  tf.reshape(image, [256, 256, 3])
  # normalize
  image = tf.cast(image, tf.float32) * (1. /255) - 0.5
  # label
  label = tf.cast(features['label'], tf.int32)
  return image, label
 
def test():
  transform2tfrecord(train_file, name , output_directory, resize_height, resize_width) #转化函数  
  img,label=disp_tfrecords(output_directory+'/'+name+'.tfrecords') #显示函数
  img,label=read_tfrecord(output_directory+'/'+name+'.tfrecords') #读取函数
  print label
 
if __name__ == '__main__':
  test()

这样就可以得到自己专属的数据集.tfrecords了  ,它可以直接用于tensorflow的数据集。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。

原文链接:http://blog.csdn.net/tengxing007/article/details/56847828

相关文章

热门资讯

2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
Intellij idea2020永久破解,亲测可用!!!
Intellij idea2020永久破解,亲测可用!!! 2020-07-29
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
歪歪漫画vip账号共享2020_yy漫画免费账号密码共享
歪歪漫画vip账号共享2020_yy漫画免费账号密码共享 2020-04-07
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
返回顶部