先放关键代码:
1
2
|
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs = 1 , shuffle = False ).dequeue() inputs = tf. slice (array, [i * BATCH_SIZE], [BATCH_SIZE]) |
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
main.py内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
import tensorflow as tf import codecs BATCH_SIZE = 6 NUM_EXPOCHES = 5 def input_producer(): array = codecs. open ( "test.txt" ).readlines() array = map ( lambda line: line.strip(), array) i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs = 1 , shuffle = False ).dequeue() inputs = tf. slice (array, [i * BATCH_SIZE], [BATCH_SIZE]) return inputs class Inputs( object ): def __init__( self ): self .inputs = input_producer() def main( * args, * * kwargs): inputs = Inputs() init = tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord = coord) sess.run(init) try : index = 0 while not coord.should_stop() and index< 10 : datalines = sess.run(inputs.inputs) index + = 1 print ( "step: %d, batch data: %s" % (index, str (datalines))) except tf.errors.OutOfRangeError: print ( "Done traing:-------Epoch limit reached" ) except KeyboardInterrupt: print ( "keyboard interrput detected, stop training" ) finally : coord.request_stop() coord.join(threads) sess.close() del sess if __name__ = = "__main__" : main() |
输出:
1
2
3
4
5
6
|
step: 1 , batch data: [ '1' '2' '3' '4' '5' '6' ] step: 2 , batch data: [ '7' '8' '9' '10' '11' '12' ] step: 3 , batch data: [ '13' '14' '15' '16' '17' '18' ] step: 4 , batch data: [ '19' '20' '21' '22' '23' '24' ] step: 5 , batch data: [ '25' '26' '27' '28' '29' '30' ] Done traing: - - - - - - - Epoch limit reached |
如果range_input_producer去掉参数num_epochs=1,则输出:
1
2
3
4
5
6
7
8
9
10
|
step: 1 , batch data: [ '1' '2' '3' '4' '5' '6' ] step: 2 , batch data: [ '7' '8' '9' '10' '11' '12' ] step: 3 , batch data: [ '13' '14' '15' '16' '17' '18' ] step: 4 , batch data: [ '19' '20' '21' '22' '23' '24' ] step: 5 , batch data: [ '25' '26' '27' '28' '29' '30' ] step: 6 , batch data: [ '1' '2' '3' '4' '5' '6' ] step: 7 , batch data: [ '7' '8' '9' '10' '11' '12' ] step: 8 , batch data: [ '13' '14' '15' '16' '17' '18' ] step: 9 , batch data: [ '19' '20' '21' '22' '23' '24' ] step: 10 , batch data: [ '25' '26' '27' '28' '29' '30' ] |
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
1
2
|
InvalidArgumentError (see above for traceback): Expected size[ 0 ] in [ 0 , 5 ], but got 6 [[Node: Slice = Slice [Index = DT_INT32, T = DT_STRING, _device = "/job:localhost/replica:0/task:0/cpu:0" ]( Slice / input , Slice / begin / _5, Slice / size)]] |
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。
以上这篇tensorflow使用range_input_producer多线程读取数据实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/lyg5623/article/details/69387917