First, the key code:
i = tf.train.range_input_producer(num_expoches, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * batch_size], [batch_size])
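As a plain-Python sketch (no TensorFlow required, with a hypothetical helper name `take_batch`), the slice on the second line is equivalent to ordinary list slicing:

```python
def take_batch(array, i, batch_size):
    # Equivalent of tf.slice(array, [i * batch_size], [batch_size]):
    # start at offset i * batch_size and take batch_size elements.
    begin = i * batch_size
    return array[begin:begin + batch_size]

data = [str(n) for n in range(1, 36)]  # stand-in for the 35 lines of test.txt
print(take_batch(data, 0, 6))  # first batch:  ['1', '2', '3', '4', '5', '6']
print(take_batch(data, 1, 6))  # second batch: ['7', '8', '9', '10', '11', '12']
```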
How it works:
The first line produces a queue containing the elements 0 to num_expoches-1. If num_epochs is specified, each element is produced exactly num_epochs times; otherwise the sequence repeats forever. shuffle controls whether the order is randomized; shuffle=False means the queue emits elements in order from 0 to num_expoches-1. When the graph runs, each thread dequeues an element, say with value i, and then slices a small piece of data out of the array as a batch according to the second line of code. For example, with num_expoches=3 and num_epochs=2, the contents of the queue are:
0,1,2,0,1,2
The queue holds only 6 elements, so only 6 batches can be produced during training; after 6 iterations, training ends.
If num_epochs is not specified, the queue content looks like this:
0,1,2,0,1,2,0,1,2,0,1,2 ...
The queue keeps generating elements indefinitely, so training can produce an unlimited number of batches, and you have to decide yourself when to stop training.
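The two queue behaviours described above can be emulated in plain Python (a sketch of the observable index sequence, not TensorFlow's actual queue implementation; `range_queue` is a hypothetical name):

```python
import itertools

def range_queue(limit, num_epochs=None):
    # Emulates tf.train.range_input_producer(limit, num_epochs=..., shuffle=False):
    # yields 0..limit-1 num_epochs times, or cycles forever when num_epochs is None.
    if num_epochs is None:
        return itertools.cycle(range(limit))  # infinite: 0,1,2,0,1,2,...
    return iter([i for _ in range(num_epochs) for i in range(limit)])

print(list(range_queue(3, num_epochs=2)))         # [0, 1, 2, 0, 1, 2]
print(list(itertools.islice(range_queue(3), 8)))  # [0, 1, 2, 0, 1, 2, 0, 1]
```

With num_epochs=2 the iterator is exhausted after 6 elements, matching the 6-batch limit above; without it, the caller must bound the loop.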
Here is the complete demo code.
Contents of the data file test.txt (the numbers 1 through 35, one per line):
1
2
3
4
5
6
7
8
9
...
35
main.py content:
import tensorflow as tf
import codecs

batch_size = 6
num_expoches = 5


def input_producer():
    array = codecs.open("test.txt").readlines()
    array = map(lambda line: line.strip(), array)
    i = tf.train.range_input_producer(num_expoches, num_epochs=1, shuffle=False).dequeue()
    inputs = tf.slice(array, [i * batch_size], [batch_size])
    return inputs


class Inputs(object):
    def __init__(self):
        self.inputs = input_producer()


def main(*args, **kwargs):
    inputs = Inputs()
    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    sess = tf.Session()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(init)
    try:
        index = 0
        while not coord.should_stop() and index < 10:
            datalines = sess.run(inputs.inputs)
            index += 1
            print("step: %d, batch data: %s" % (index, str(datalines)))
    except tf.errors.OutOfRangeError:
        print("Done training: -------Epoch limit reached")
    except KeyboardInterrupt:
        print("keyboard interrupt detected, stop training")
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
        del sess


if __name__ == "__main__":
    main()
Output:
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done training: -------Epoch limit reached
If the num_epochs=1 argument is removed from range_input_producer, the output is:
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']
One thing to note: the file has 35 lines in total. batch_size = 6 means each batch contains 6 items, and num_expoches = 5 yields 5 batches. If num_expoches = 6, a total of 36 items would be needed, and the following error is raised:
InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
    [[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]
The error occurs because the sixth batch (index i=5) would start at element 30 and try to read 6 elements, overrunning the 35-line file. Since 35 // batch_size = 5, num_expoches can only be between 0 and 5.
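A quick way to pick a safe value is to derive it from the file length; this is a sketch under the assumption that only full batches are used (the helper name `max_num_expoches` is hypothetical):

```python
def max_num_expoches(num_lines, batch_size):
    # Only full batches can be sliced out of the data, so the number of
    # valid batch indices is num_lines // batch_size (integer division).
    return num_lines // batch_size

print(max_num_expoches(35, 6))  # 5: indices 0..4 are safe; index 5 would
                                # need elements 31..36 and overrun the file
```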