Saving and loading TensorFlow models

Source: Internet
Author: User


TensorFlow model saving/loading

When we deploy an algorithm model online, we must first save the trained model. TensorFlow saves models differently from sklearn. In sklearn it is very direct: the dump and load methods of sklearn.externals.joblib are enough to save and load a model. TensorFlow has the concepts of graphs and operations, which makes storing and loading models a little more involved.
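For comparison, here is a minimal sketch of the sklearn approach. It assumes a recent scikit-learn release, in which `sklearn.externals.joblib` has been removed in favour of the standalone `joblib` package. The toy data and the file name `model.pkl` are illustrative only.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy training data (illustrative only)
X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
labels = np.array([0, 0, 1, 1])
model = LogisticRegression().fit(X, labels)

# One call writes the whole fitted estimator to disk...
path = os.path.join(tempfile.mkdtemp(), "model.pkl")
joblib.dump(model, path)

# ...and one call reads it back, ready to predict, with no model definition
loaded = joblib.load(path)
print(loaded.predict(X))
```

The loaded object carries both its structure and its learned parameters, which is exactly what the basic TensorFlow method below does not do.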

I. Basic Methods

If you search the Internet for how to save a TensorFlow model, most results describe the basic method, which is:

Save

  • Define Variables
  • Save using the saver.save() method

Load

  • Define Variables
  • Load using the saver.restore() method

For example, the code for saving is as follows:

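```python
import os

import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs on TF 2.x

tf.disable_v2_behavior()

W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='w')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')
# tf.initialize_all_variables() is deprecated; use the global initializer
init = tf.global_variables_initializer()
saver = tf.train.Saver()

os.makedirs("save", exist_ok=True)  # make sure the target directory exists
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/model.ckpt")
```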

The code for loading is as follows:

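```python
import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs on TF 2.x

tf.disable_v2_behavior()

# Same names and shapes as when saving; the random initial values do not
# matter because restore() overwrites them with the saved values
W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='w')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
    print(sess.run(W))
```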
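The save and load steps can be checked end to end with a self-contained sketch, assuming TensorFlow 1.x or 2.x (through `tf.compat.v1`). Each half builds its own `tf.Graph`, standing in for the two separate scripts, and a temporary directory replaces `save/`. Because variables are matched by name (`w`, `b`), the `truncated_normal` initial values are replaced by the saved ones.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save: define the variables and write their values
with tf.Graph().as_default():
    W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name="w")
    b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name="b")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt)

# Load: redefine variables with the same names and shapes, then restore
with tf.Graph().as_default():
    W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name="w")
    b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name="b")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # No initializer is run: restore() assigns every saved variable
        saver.restore(sess, ckpt)
        w_val, b_val = sess.run([W, b])

print(w_val)
print(b_val)
```

Renaming `w` in the second graph would make `restore()` fail to find it in the checkpoint, which is why the structure and names have to be repeated exactly.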

This method is inconvenient: to use the model, you must redefine its structure and then load the variable values by name. In many cases we would rather read a file and use the model directly, without redefining it, so another method is needed.

II. Loading without redefining the network structure

tf.train.import_meta_graph(meta_graph_or_file, clear_devices=False, import_scope=None, **kwargs)

This function loads all nodes of the saved graph into the current default graph and returns a Saver. In other words, when we save, we store not only the variable values but also the nodes of the graph (in the .meta file), so the structure of the model is saved too.
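A small sketch of this behaviour, assuming `tf.compat.v1` and hypothetical node names (`x`, `w`, `out`): it saves a tiny graph, then imports it into an empty graph with `tf.train.import_meta_graph`. The second half defines no layers, yet the operations come back by name and the returned object is a `tf.train.Saver` that can restore the weights.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "tiny")

# Build and save a small graph: out = x @ w
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=(None, 2), name="x")
    w = tf.Variable([[1.0], [2.0]], name="w")
    out = tf.matmul(x, w, name="out")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, prefix)  # also writes tiny.meta

# Import the structure into a fresh, empty graph; nothing is redefined
with tf.Graph().as_default() as g:
    saver = tf.train.import_meta_graph(prefix + ".meta")
    op_names = {op.name for op in g.get_operations()}
    with tf.Session() as sess:
        saver.restore(sess, prefix)
        # Tensors can be addressed by name strings in fetches and feeds
        result = sess.run("out:0", feed_dict={"x:0": [[1.0, 1.0]]})

print(isinstance(saver, tf.train.Saver))  # True
print(result)  # [[3.]]
```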

For example, if we want to use the final prediction y after loading, we should add it to a collection during the training phase. The code is as follows:

Save

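```python
import os

import numpy as np
import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs on TF 2.x

tf.disable_v2_behavior()

in_dim, h1_dim, out_dim = 4, 8, 3  # hypothetical sizes, not given in the text

### Define the model
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')
w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)
### Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)

# Assumed loss and optimizer (train_op is not defined in the text)
loss = tf.reduce_mean(-tf.reduce_sum(input_y * tf.math.log(y + 1e-10), axis=1))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Create the saver; with no variable list, Saver() saves all global variables
saver = tf.train.Saver()
# Add y to a collection so it can be retrieved at prediction time
tf.add_to_collection('pred_network', y)

# Hypothetical training data
x_train = np.random.rand(32, in_dim).astype(np.float32)
y_train = np.eye(out_dim, dtype=np.float32)[np.random.randint(out_dim, size=32)]

os.makedirs('my-save-dir', exist_ok=True)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(1000000):
    sess.run(train_op, feed_dict={input_x: x_train, input_y: y_train, keep_prob: 0.5})
    if step % 1000 == 0:
        # Save a checkpoint; this also exports a meta graph
        # named 'my-model-{global_step}.meta' by default
        saver.save(sess, 'my-save-dir/my-model', global_step=step)
```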

Load

```python
import tensorflow.compat.v1 as tf  # TF1-style graph API; also runs on TF 2.x

tf.disable_v2_behavior()

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    # tf.get_collection() returns a list; only its first element is needed here
    y = tf.get_collection('pred_network')[0]

    graph = tf.get_default_graph()

    # y depends on placeholders, so sess.run(y) needs them filled with the
    # samples to predict and the matching parameters. They are obtained
    # through the graph's get_operation_by_name method.
    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

    # Use y to predict (replace ... with the samples to predict)
    sess.run(y, feed_dict={input_x: ..., keep_prob: 1.0})
```
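A condensed, runnable version of the same save-then-predict flow, assuming `tf.compat.v1`. To keep it short, training is skipped (the variables are saved right after initialization), and the layer sizes and the step number `10` are hypothetical. The two `tf.Graph` blocks stand in for the training and prediction programs; the prediction side never redefines the network and only looks up `pred_network`, `input_x`, and `keep_prob` by name.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

save_dir = tempfile.mkdtemp()
in_dim, h1_dim, out_dim = 4, 8, 3  # hypothetical sizes

# Training side: build the network and register y in a collection
with tf.Graph().as_default():
    input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name="input_x")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name="w1")
    b1 = tf.Variable(tf.zeros([h1_dim]), name="b1")
    w2 = tf.Variable(tf.truncated_normal([h1_dim, out_dim], stddev=0.1), name="w2")
    b2 = tf.Variable(tf.zeros([out_dim]), name="b2")
    hidden1 = tf.nn.dropout(tf.nn.relu(tf.matmul(input_x, w1) + b1), keep_prob)
    y = tf.nn.softmax(tf.matmul(hidden1, w2) + b2)
    tf.add_to_collection("pred_network", y)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, os.path.join(save_dir, "my-model"), global_step=10)

# Prediction side: the structure comes from the .meta file
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(os.path.join(save_dir, "my-model-10.meta"))
        new_saver.restore(sess, os.path.join(save_dir, "my-model-10"))
        y = tf.get_collection("pred_network")[0]
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        keep_prob = graph.get_operation_by_name("keep_prob").outputs[0]
        samples = np.random.rand(5, in_dim).astype(np.float32)
        probs = sess.run(y, feed_dict={input_x: samples, keep_prob: 1.0})

print(probs.shape)  # (5, 3)
```

Feeding `keep_prob: 1.0` disables dropout at prediction time, which is why it is a placeholder rather than a constant.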

Note the following two points:

1. The file name passed to saver.restore(). During saver.save(), each checkpoint is stored as three files, for example:
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph takes the name of the .meta file. All weights are stored in my-model-10000.data-00000-of-00001, but passing that file name to restore() raises an error. restore() expects the common prefix (my-model-10000); the prefix of the most recent checkpoint can be obtained with tf.train.latest_checkpoint(checkpoint_dir).

2. The model output y depends on placeholders, so the corresponding data must be fed in sess.run(). You therefore have to look up each placeholder by its name with the graph's get_operation_by_name method.
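Both notes can be checked with a small sketch, assuming `tf.compat.v1`; the temporary directory, the step numbers, and the names `my-model`, `input_x`, `v`, and `total` are hypothetical. It lists the files one checkpoint produces, restores from the prefix returned by `tf.train.latest_checkpoint`, and shows that running an op that depends on an unfed placeholder fails until that placeholder, looked up by name, is fed. `graph.get_tensor_by_name('input_x:0')` is an equivalent shortcut for `get_operation_by_name('input_x').outputs[0]`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()

# Save three checkpoints of a tiny graph: total = sum(input_x * v)
with tf.Graph().as_default():
    input_x = tf.placeholder(tf.float32, shape=(None, 2), name="input_x")
    v = tf.Variable(2.0, name="v")
    tf.reduce_sum(input_x * v, name="total")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in (0, 10, 20):
            saver.save(sess, os.path.join(ckpt_dir, "my-model"), global_step=step)

# Note 1: one checkpoint = .meta + .index + .data-*-of-* files
files = sorted(f for f in os.listdir(ckpt_dir) if f.startswith("my-model-20"))
print(files)

# restore() wants the prefix, which latest_checkpoint() returns
prefix = tf.train.latest_checkpoint(ckpt_dir)

with tf.Graph().as_default() as graph:
    restorer = tf.train.import_meta_graph(prefix + ".meta")
    with tf.Session() as sess:
        restorer.restore(sess, prefix)

        # Note 2: placeholders are looked up by name and must be fed
        input_x = graph.get_operation_by_name("input_x").outputs[0]
        total = graph.get_tensor_by_name("total:0")
        try:
            sess.run(total)
            missing_feed_error = False
        except tf.errors.InvalidArgumentError:
            missing_feed_error = True
        result = sess.run(total, feed_dict={input_x: np.ones((3, 2), np.float32)})

print(prefix)
print(missing_feed_error, result)  # True 12.0
```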

