Two methods of TensorFlow model saving/loading

Source: Internet
Author: User

Before an algorithmic model can be used online, the trained model has to be saved. TensorFlow saves models differently from sklearn. In sklearn this is straightforward: the dump and load methods of sklearn.externals.joblib save and load a model. TensorFlow, because of concepts such as graphs and operations, makes saving and loading a model slightly more troublesome.

I. The basic method

Searching online for how to save a TensorFlow model mostly turns up the basic method.
Save: define the variables, then save them with the saver.save() method.
Load: define the variables, then load them with the saver.restore() method.

For example, the save code is as follows:

import tensorflow as tf
import numpy as np

# Variables with fixed values, so the saved checkpoint is easy to recognize
W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='W')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')

# tf.global_variables_initializer() replaces the deprecated tf.initialize_all_variables()
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/model.ckpt")

The load code is as follows:

import tensorflow as tf
import numpy as np

# The variables must be redefined with the same names and shapes;
# their random initial values are overwritten by restore()
W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='W')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
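
To check that the restore really matches variables by name, the two listings above can be combined into one round trip. This is a minimal sketch, not the author's code: it assumes a TensorFlow build that exposes the 1.x API as tensorflow.compat.v1 (TensorFlow 1.14 or later, including 2.x), and it writes the checkpoint to a temporary directory instead of save/.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: the 1.x API is available under compat.v1

tf.disable_eager_execution()

# Assumption: a throwaway directory stands in for the article's "save/" folder
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save: variables with known values
with tf.Graph().as_default():
    W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name="W")
    b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name="b")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Load: a fresh graph with the same names and shapes but random initial values
with tf.Graph().as_default():
    W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name="W")
    b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name="b")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # No initializer is run: restore() assigns the saved values directly
        saver.restore(sess, ckpt_path)
        restored_W, restored_b = sess.run([W, b])

print(restored_W)
print(restored_b)
```

The restored arrays hold the saved values rather than the random ones, which shows that only the names and shapes of the redefined variables matter.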

The inconvenience of this method is that the model structure must be defined again before the model can be used, after which the saved values are loaded into the variables with matching names. Most of the time, though, we would rather read a file and use the model directly, without redefining it. That calls for a different approach.

II. A method that does not require redefining the network structure

tf.train.import_meta_graph

import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)

This method loads all the nodes of the saved graph from the file into the current default graph and returns a saver. In other words, saving stores not only the values of the variables but also the nodes of the corresponding graph, so the structure of the model is preserved as well.

For example, to keep y, the tensor that computes the final prediction, we add it to a collection during the training phase. The save code is as follows:

### Define the model
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
### Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# Create the saver
saver = tf.train.Saver(...)
# To use y at prediction time, add it to a collection before saving
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in range(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Save a checkpoint; by default this also exports a meta_graph
        # named 'my-model-{global_step}.meta'.
        saver.save(sess, 'my-model', global_step=step)

The load code is as follows:

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
  new_saver.restore(sess, 'my-save-dir/my-model-10000')
  # tf.get_collection() returns a list; here only its first element is needed
  y = tf.get_collection('pred_network')[0]

  graph = tf.get_default_graph()

  # y depends on placeholders, so sess.run(y) must be fed the actual samples
  # to predict and the related parameters. The placeholders are retrieved
  # with the graph's get_operation_by_name method.
  input_x = graph.get_operation_by_name('input_x').outputs[0]
  keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

  # Use y to predict
  sess.run(y, feed_dict={input_x: ....., keep_prob: 1.0})
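
The listings above leave train_op and the input data elided, so they cannot be run as they stand. The sketch below is a much smaller stand-in for the same save and load sequence. It is an illustration under assumptions rather than the author's model: it uses tensorflow.compat.v1 (TensorFlow 1.14 or later, including 2.x), a single weight matrix of ones, no training loop, and a temporary checkpoint directory.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf  # assumption: the 1.x API is available under compat.v1

tf.disable_eager_execution()

# Assumption: a throwaway directory stands in for 'my-save-dir'
ckpt_dir = tempfile.mkdtemp()

# Training side: build a tiny graph, register y in a collection, save
with tf.Graph().as_default():
    input_x = tf.placeholder(tf.float32, shape=(None, 3), name="input_x")
    w = tf.Variable(tf.ones([3, 2]), name="w")
    y = tf.matmul(input_x, w, name="y")
    tf.add_to_collection("pred_network", y)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint prefix, here '<ckpt_dir>/my-model-0'
        saved_prefix = saver.save(sess, os.path.join(ckpt_dir, "my-model"), global_step=0)

# Prediction side: nothing from the model definition is repeated
load_graph = tf.Graph()
with load_graph.as_default():
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(saved_prefix + ".meta")
        new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        y = tf.get_collection("pred_network")[0]
        input_x = load_graph.get_operation_by_name("input_x").outputs[0]
        sample = np.ones((1, 3), dtype=np.float32)
        prediction = sess.run(y, feed_dict={input_x: sample})

print(prediction)
```

With a weight matrix of ones and an input row of ones, each output is the sum of three ones, which makes it easy to confirm that the restored graph and weights are the saved ones.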

Two points to keep in mind here:
First, the file name passed to saver.restore(). At saver.save() time, each checkpoint is saved as three files, for example
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
The meta file name is what import_meta_graph takes. The weights are stored in the my-model-10000.data-00000-of-00001 file, but passing that file name to the restore method raises an error. Pass the prefix instead (here my-model-10000), which can be obtained with the tf.train.latest_checkpoint(checkpoint_dir) method.

Second, the model's y depends on placeholders, which must be fed the corresponding data at sess.run() time. They therefore also have to be retrieved from the graph by their specific placeholder names, using the get_operation_by_name method.
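
For the first point, the relationship between the three file names and the prefix can be sketched without TensorFlow. The helper checkpoint_prefix below is hypothetical, written only for this illustration. In practice tf.train.latest_checkpoint remains the usual way to obtain the prefix.

```python
import re

# Suffixes that saver.save() appends to the checkpoint prefix
_CHECKPOINT_SUFFIX = re.compile(r"\.(meta|index|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(filename):
    """Hypothetical helper: strip the checkpoint suffix to recover the prefix."""
    return _CHECKPOINT_SUFFIX.sub("", filename)


files = [
    "my-model-10000.meta",
    "my-model-10000.index",
    "my-model-10000.data-00000-of-00001",
]

# All three files share the one prefix that saver.restore() expects
prefixes = {checkpoint_prefix(name) for name in files}
print(prefixes)
```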
