This article introduces saving and restoring models in TensorFlow 1.0 with tf.train.Saver. I share it here as a reference; let's take a look together.
Saving the parameters of a well-trained model for later validation or testing is something we do all the time. TensorFlow provides model saving through the tf.train.Saver() module.
To save the model, first create a Saver object:
saver = tf.train.Saver()
When creating this Saver object, one parameter we often use is max_to_keep, which sets how many checkpoints to keep; the default is 5 (max_to_keep=5), meaning the 5 most recent models are kept. If you want to keep a checkpoint for every training generation (epoch), you can set max_to_keep to None or 0, for example:
saver = tf.train.Saver(max_to_keep=0)
However, this is not recommended: it eats disk space and rarely has any practical use.
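If you do want occasional long-term snapshots without keeping every single one, tf.train.Saver also accepts a keep_checkpoint_every_n_hours parameter. A minimal sketch (the value 2 is just an example):

# Keep the 5 most recent checkpoints, plus one extra checkpoint
# every 2 hours of training (these extras are not deleted).
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)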
Of course, if you only want to keep the model from the last generation, just set max_to_keep to 1, that is:
saver = tf.train.Saver(max_to_keep=1)
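By default, Saver handles all variables in the graph. If you only need a subset, you can pass var_list when constructing it. A minimal sketch, where w1 and w2 are hypothetical variables standing in for ones from your own graph:

w1 = tf.Variable(tf.zeros([784, 1024]), name='w1')  # hypothetical variable
w2 = tf.Variable(tf.zeros([1024, 512]), name='w2')  # hypothetical variable
# Only w1 and w2 are written to (and later read from) the checkpoint.
saver = tf.train.Saver(var_list={'w1': w1, 'w2': w2}, max_to_keep=1)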
Once the Saver object is created, you can save the trained model, for example:
saver.save(sess, 'ckpt/mnist.ckpt', global_step=step)
The first parameter is the session sess, which goes without saying. The second parameter sets the save path and name, and the third parameter appends the training step number to the model name as a suffix:
saver.save(sess, 'my-model', global_step=0)    ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
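For reference, on TensorFlow 0.11 and later a single save actually writes several files rather than one. After the save above, the directory would contain roughly:

checkpoint                             (text file recording the latest model)
my-model-1000.meta                     (the graph structure)
my-model-1000.index
my-model-1000.data-00000-of-00001      (the variable values)

When restoring, you pass the common prefix my-model-1000, not any single one of these files.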
Let's look at an MNIST example:
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.int32, [None,])

dense1 = tf.layers.dense(inputs=x,
                         units=1024,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
dense2 = tf.layers.dense(inputs=dense1,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
logits = tf.layers.dense(inputs=dense2,
                         units=10,
                         activation=None,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)

loss = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=logits)
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

saver = tf.train.Saver(max_to_keep=1)
for i in range(100):  # number of training generations; adjust as needed
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f' % (i, val_loss, val_acc))
    saver.save(sess, 'ckpt/mnist.ckpt', global_step=i + 1)
sess.close()
The call saver.save(sess, 'ckpt/mnist.ckpt', global_step=i+1) is the line that saves the model. Although I save once per training generation here, each newly saved model overwrites the previous one, so in the end only the last generation is kept. We can therefore save time by moving the save call out of the loop, as sketched below (this only applies when max_to_keep=1; otherwise it has to stay inside the loop).
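A sketch of that variant, showing only the training loop (n_steps is a hypothetical name standing for however many iterations you train):

for i in range(n_steps):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, 'ckpt/mnist.ckpt')  # save once, after training finishes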
In practice, the last generation is not necessarily the most accurate one, so instead of saving the last generation by default we may want to save the generation with the best validation accuracy. That just takes one intermediate variable and one if statement:
saver = tf.train.Saver(max_to_keep=1)
max_acc = 0
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f' % (i, val_loss, val_acc))
    if val_acc > max_acc:
        max_acc = val_acc
        saver.save(sess, 'ckpt/mnist.ckpt', global_step=i + 1)
sess.close()
If we want to keep the three generations with the highest accuracy, and also record every validation accuracy, we can write them to a txt file:
saver = tf.train.Saver(max_to_keep=3)
max_acc = 0
f = open('ckpt/acc.txt', 'w')
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f' % (i, val_loss, val_acc))
    f.write(str(i + 1) + ', val_acc: ' + str(val_acc) + '\n')
    if val_acc > max_acc:
        max_acc = val_acc
        saver.save(sess, 'ckpt/mnist.ckpt', global_step=i + 1)
f.close()
sess.close()
Restoring a model uses the restore() function, which takes two parameters: restore(sess, save_path), where save_path is the path of the saved model. We can use tf.train.latest_checkpoint() to get the last saved model automatically. For example:
model_file = tf.train.latest_checkpoint('ckpt/')
saver.restore(sess, model_file)
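Note that restore() only fills in variable values; the graph itself must already be built before calling it. If you don't want to redefine the network in code, tf.train.import_meta_graph() can first rebuild the graph from the saved .meta file. A sketch, assuming the last checkpoint was saved as mnist.ckpt-100:

import tensorflow as tf

sess = tf.Session()
# Rebuild the graph structure from the .meta file written by saver.save()
saver = tf.train.import_meta_graph('ckpt/mnist.ckpt-100.meta')
# Then restore the variable values into it
saver.restore(sess, tf.train.latest_checkpoint('ckpt/'))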
The second half of the program can then be replaced with:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

is_train = False
saver = tf.train.Saver(max_to_keep=3)

# training phase
if is_train:
    max_acc = 0
    f = open('ckpt/acc.txt', 'w')
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
        val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print('epoch:%d, val_loss:%f, val_acc:%f' % (i, val_loss, val_acc))
        f.write(str(i + 1) + ', val_acc: ' + str(val_acc) + '\n')
        if val_acc > max_acc:
            max_acc = val_acc
            saver.save(sess, 'ckpt/mnist.ckpt', global_step=i + 1)
    f.close()

# validation phase
else:
    model_file = tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess, model_file)
    val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f' % (val_loss, val_acc))
sess.close()
The lines to focus on are the ones that save and restore the model (saver.save and saver.restore). A bool variable is_train switches between the training and validation phases.
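Since max_to_keep=3 keeps the three best generations on disk, you are not limited to the latest one. tf.train.get_checkpoint_state() lists everything still available; a sketch:

ckpt = tf.train.get_checkpoint_state('ckpt/')
print(ckpt.model_checkpoint_path)        # the most recently saved model
print(ckpt.all_model_checkpoint_paths)   # all checkpoints still on disk
# Restore any specific one by passing its path prefix directly:
saver.restore(sess, ckpt.all_model_checkpoint_paths[0])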
The entire source program:
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("mnist_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.int32, [None,])

dense1 = tf.layers.dense(inputs=x,
                         units=1024,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
dense2 = tf.layers.dense(inputs=dense1,
                         units=512,
                         activation=tf.nn.relu,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)
logits = tf.layers.dense(inputs=dense2,
                         units=10,
                         activation=None,
                         kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
                         kernel_regularizer=tf.nn.l2_loss)

loss = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=logits)
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

is_train = True
saver = tf.train.Saver(max_to_keep=3)

# training phase
if is_train:
    max_acc = 0
    f = open('ckpt/acc.txt', 'w')
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
        val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print('epoch:%d, val_loss:%f, val_acc:%f' % (i, val_loss, val_acc))
        f.write(str(i + 1) + ', val_acc: ' + str(val_acc) + '\n')
        if val_acc > max_acc:
            max_acc = val_acc
            saver.save(sess, 'ckpt/mnist.ckpt', global_step=i + 1)
    f.close()

# validation phase
else:
    model_file = tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess, model_file)
    val_loss, val_acc = sess.run([loss, acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f' % (val_loss, val_acc))
sess.close()