tensorflow1.0學習之模型的儲存與恢複(Saver)_python

來源:互聯網
上載者:User
這篇文章主要介紹了tensorflow1.0學習之模型的儲存與恢複(Saver) ,現在分享給大家,也給大家做個參考。一起過來看看吧

將訓練好的模型參數儲存起來,以便以後進行驗證或測試,這是我們經常要做的事情。tf裡面提供模型儲存的是tf.train.Saver()模組。

模型儲存,先要建立一個Saver對象:如

saver=tf.train.Saver()

在建立這個Saver對象的時候,有一個參數我們經常會用到,就是 max_to_keep 參數,這個是用來設定儲存模型的個數,預設為5,即 max_to_keep=5,儲存最近的5個模型。如果你想每訓練一代(epoch)就想儲存一次模型,則可以將 max_to_keep設定為None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是這樣做除了多佔用硬碟,並沒有實際多大的用處,因此不推薦。

當然,如果你只想儲存最後一代的模型,則只需要將max_to_keep設定為1即可,即

saver=tf.train.Saver(max_to_keep=1)

建立完saver對象後,就可以儲存訓練好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一個參數sess,這個就不用說了。第二個參數設定儲存的路徑和名字,第三個參數將訓練的次數作為尾碼加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一個mnist執行個體:

# -*- coding: utf-8 -*-"""Created on Sun Jun 4 10:29:48 2017@author: Administrator"""import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784])y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x,            units=1024,            activation=tf.nn.relu,           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),           kernel_regularizer=tf.nn.l2_loss)dense2= tf.layers.dense(inputs=dense1,            units=512,            activation=tf.nn.relu,           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),           kernel_regularizer=tf.nn.l2_loss)logits= tf.layers.dense(inputs=dense2,             units=10,             activation=None,            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),            kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())saver=tf.train.Saver(max_to_keep=1)for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)sess.close()

代碼中紅色部分就是儲存模型的代碼,雖然我在每訓練完一代的時候,都進行了儲存,但後一次儲存的模型會覆蓋前一次的,最終只會儲存最後一次。因此我們可以節省時間,將儲存代碼放到迴圈之外(僅適用max_to_keep=1,否則還是需要放在迴圈內).

在實驗中,最後一代可能並不是驗證精度最高的一代,因此我們並不想預設儲存最後一代,而是想儲存驗證精度最高的一代,則加個中間變數和判斷語句就可以了。

saver=tf.train.Saver(max_to_keep=1)max_acc=0for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) if val_acc>max_acc:   max_acc=val_acc   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)sess.close()

如果我們想儲存驗證精度最高的三代,且把每次的驗證精度也隨之儲存下來,則我們可以產生一個txt檔案用於儲存。

saver=tf.train.Saver(max_to_keep=3)max_acc=0f=open('ckpt/acc.txt','w')for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys}) val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels}) print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc)) f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n') if val_acc>max_acc:   max_acc=val_acc   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()sess.close()

模型的恢複用的是restore()函數,它需要兩個參數restore(sess, save_path),save_path指的是儲存的模型路徑。我們可以使用tf.train.latest_checkpoint()來自動擷取最後一次儲存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)

則程式後半段代碼我們可以改為:

sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=Falsesaver=tf.train.Saver(max_to_keep=3)#訓練階段if is_train:  max_acc=0  f=open('ckpt/acc.txt','w')  for i in range(100):   batch_xs, batch_ys = mnist.train.next_batch(100)   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')   if val_acc>max_acc:     max_acc=val_acc     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)  f.close()#驗證階段else:  model_file=tf.train.latest_checkpoint('ckpt/')  saver.restore(sess,model_file)  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))sess.close()

標紅的地方,就是與儲存、恢複模型相關的代碼。用一個bool型變數is_train來控制訓練和驗證兩個階段。

整個來源程式:

# -*- coding: utf-8 -*-"""Created on Sun Jun 4 10:29:48 2017@author: Administrator"""import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=False)x = tf.placeholder(tf.float32, [None, 784])y_=tf.placeholder(tf.int32,[None,])dense1 = tf.layers.dense(inputs=x,            units=1024,            activation=tf.nn.relu,           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),           kernel_regularizer=tf.nn.l2_loss)dense2= tf.layers.dense(inputs=dense1,            units=512,            activation=tf.nn.relu,           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),           kernel_regularizer=tf.nn.l2_loss)logits= tf.layers.dense(inputs=dense2,             units=10,             activation=None,            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),            kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=Truesaver=tf.train.Saver(max_to_keep=3)#訓練階段if is_train:  max_acc=0  f=open('ckpt/acc.txt','w')  for i in range(100):   batch_xs, batch_ys = mnist.train.next_batch(100)   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')   if val_acc>max_acc:     max_acc=val_acc     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)  f.close()#驗證階段else:  model_file=tf.train.latest_checkpoint('ckpt/')  saver.restore(sess,model_file)  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))sess.close()

相關文章

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.