直接看代码例子,有详细注释!!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
import tensorflow as tf import numpy as np d = np.arange( 0 , 60 ).reshape([ 6 , 10 ]) # 将array转化为tensor data = tf.data.Dataset.from_tensor_slices(d) # 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本 # buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size, # 此时会再次打乱 data = data.shuffle(buffer_size = 3 ) # 每次从buffer中抽取4个样本 data = data.batch( 4 ) # 将data数据集重复,其实就是2个epoch数据集 data = data.repeat( 2 ) # 构造获取数据的迭代器 iters = data.make_one_shot_iterator() # 每次从迭代器中获取一批数据 batch = iters.get_next() sess = tf.Session() sess.run(batch) # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
In [ 21 ]: d Out[ 21 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) In [ 22 ]: sess.run(batch) Out[ 22 ]: array([[ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ]]) In [ 23 ]: sess.run(batch) Out[ 23 ]: array([[ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ]]) |
从输出结果可以看出:
shuffle是按顺序将数据放入buffer里面的;
当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。
那么,当repeat函数在shuffle之前会怎么样呢?如下:
1
2
3
4
5
|
data = data.repeat( 2 ) data = data.shuffle(buffer_size = 3 ) data = data.batch( 4 ) |
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
In [ 25 ]: sess.run(batch) Out[ 25 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) In [ 26 ]: sess.run(batch) Out[ 26 ]: array([[ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ], [ 30 , 31 , 32 , 33 , 34 , 35 , 36 , 37 , 38 , 39 ]]) In [ 27 ]: sess.run(batch) Out[ 27 ]: array([[ 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 ], [ 50 , 51 , 52 , 53 , 54 , 55 , 56 , 57 , 58 , 59 ], [ 20 , 21 , 22 , 23 , 24 , 25 , 26 , 27 , 28 , 29 ], [ 40 , 41 , 42 , 43 , 44 , 45 , 46 , 47 , 48 , 49 ]]) |
可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。
以上这篇TensorFlow dataset.shuffle、batch、repeat的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/sgyuanshi/article/details/90183610