TensorFlow model save/load
When we deploy a model online, we must first save the trained model. TensorFlow does not save models the way scikit-learn does. With scikit-learn it is straightforward: a model can be saved and loaded with the dump and load methods of sklearn.externals.joblib. In TensorFlow, because of concepts such as graphs and operations, saving and loading a model is slightly more troublesome.
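For comparison, a minimal scikit-learn sketch of that dump/load round trip (the LogisticRegression and load_iris choices here are only illustrative, not part of the original text):

from sklearn.datasets import load_iris
from sklearn.externals import joblib
from sklearn.linear_model import LogisticRegression

# Fit any estimator, dump it to disk, and load it back.
X, y = load_iris(return_X_y=True)
model = LogisticRegression().fit(X, y)
joblib.dump(model, 'model.pkl')
restored = joblib.load('model.pkl')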
I. The basic method

If you search online for how to save a TensorFlow model, most of what you find is the basic method:

Save: define the variables, then save them with the Saver.save() method.
Load: define the variables, then load them with the Saver.restore() method.
For example, the save code is as follows:
import tensorflow as tf
import numpy as np

W = tf.Variable([[1, 1, 1], [2, 2, 2]], dtype=tf.float32, name='W')
b = tf.Variable([[0, 1, 2]], dtype=tf.float32, name='b')

init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/model.ckpt")
The load code is as follows:
import tensorflow as tf
import numpy as np

W = tf.Variable(tf.truncated_normal(shape=(2, 3)), dtype=tf.float32, name='W')
b = tf.Variable(tf.truncated_normal(shape=(1, 3)), dtype=tf.float32, name='b')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
The inconvenience of this method is that, when using the model, you must redefine its structure exactly as before and then load the saved values into the variables with the corresponding names. Most of the time, however, we would prefer to read a file and use the model directly rather than redefine it, so a different approach is needed.

II. A method that does not require redefining the network structure: tf.train.import_meta_graph
import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)
This function loads all the nodes of the saved graph from the file into the current default graph and returns a Saver. In other words, when we save, we do not only save the values of the variables; the nodes of the corresponding graph are saved as well, so the structure of the model is also preserved.
For example, if we want the tensor y, which computes the final prediction, to be available after loading, we should add it to a collection during the training phase. The save code is as follows:
### Define the model
# (in_dim, out_dim, h1_dim and train_op are assumed to be defined elsewhere)
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, keep_prob)

### Define the prediction target
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)

# Create the saver
saver = tf.train.Saver(...)

# If y needs to be saved so it can be used at prediction time
tf.add_to_collection('pred_network', y)

sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # Save a checkpoint; a meta_graph named 'my-model-{global_step}.meta'
        # is also exported by default.
        saver.save(sess, 'my-model', global_step=step)
The load code is as follows:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    # tf.get_collection() returns a list; we only need its first element here.
    y = tf.get_collection('pred_network')[0]

    graph = tf.get_default_graph()

    # Because y depends on placeholders, sess.run(y) requires feeding those
    # placeholders with the actual samples to be predicted and the corresponding
    # parameters; they are obtained through the graph's get_operation_by_name method.
    input_x = graph.get_operation_by_name('input_x').outputs[0]
    keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

    # Use y to predict
    sess.run(y, feed_dict={input_x: ....., keep_prob: 1.0})
There are two points to note here:
First, note the file name passed to saver.restore(). When saver.save() is called, each checkpoint produces three files, for example:

my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001

The .meta file name is what you pass to import_meta_graph. We know that the weights are stored in the my-model-10000.data-00000-of-00001 file, but if you pass that file name to the restore method you will get an error; you should pass the prefix instead. This prefix can be obtained with the tf.train.latest_checkpoint(checkpoint_dir) method.
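As a minimal sketch of that (assuming the checkpoints were written to a hypothetical directory my-save-dir, as in the example above):

import tensorflow as tf

checkpoint_dir = 'my-save-dir'  # hypothetical directory holding the checkpoints

# latest_checkpoint returns the prefix of the most recent checkpoint,
# e.g. 'my-save-dir/my-model-10000', or None if no checkpoint exists.
ckpt_prefix = tf.train.latest_checkpoint(checkpoint_dir)

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    new_saver.restore(sess, ckpt_prefix)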
Second, the model's y depends on placeholders, and sess.run() must feed them with the corresponding data, so these placeholders also have to be fetched from the graph by their specific names using the get_operation_by_name method.
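For illustration, a minimal sketch of this lookup (using the placeholder names defined in the training code above); graph.get_tensor_by_name with the ':0' output-index suffix is an equivalent alternative to get_operation_by_name(...).outputs[0]:

import tensorflow as tf

graph = tf.get_default_graph()

# Fetch the placeholder tensors by the names given when the graph was built.
input_x = graph.get_operation_by_name('input_x').outputs[0]
keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

# Equivalent lookup: a tensor name is its operation name plus an output index.
input_x = graph.get_tensor_by_name('input_x:0')
keep_prob = graph.get_tensor_by_name('keep_prob:0')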