Export a TensorFlow network to a single file


Sometimes we need to export a TensorFlow model to a single file (containing both the graph definition and the weights) so that it can easily be used elsewhere, for example when deploying the network in C++. By default, tf.train.write_graph() exports only the graph definition (without weights), while tf.train.Saver() writes the weights to separate checkpoint files, so another approach is needed.
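For contrast, a rough sketch of that default two-file workflow (assuming TensorFlow 1.x; the file names graph_only.pbtxt and model.ckpt are just placeholders) might look like this:

import tensorflow as tf

a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
output = tf.add(a, b, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Writes the graph definition only, without any variable values
    tf.train.write_graph(sess.graph_def, '.', 'graph_only.pbtxt')
    # Variable values go to separate checkpoint files (model.ckpt.*)
    tf.train.Saver().save(sess, './model.ckpt')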

We know that a GraphDef file does not contain the values of the variables in the network (which is usually where the weights are stored), but it does contain constant values. So if we can convert the variables to constants, we can store both the network architecture and the weights in a single file.

We can freeze the weights and save the network as follows:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# Construct the network
a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# Be sure to give the output tensor a name
output = tf.add(a, b, name='out')

# Convert the variables to constants and write the network to a file
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The output tensor names must be listed here
    graph = convert_variables_to_constants(sess, sess.graph_def, ['out'])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

When you restore the network, you can use the following method:

import tensorflow as tf

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))

The output is:

[array([[ 7.],
       [ 8.]], dtype=float32)]

As you can see, the weights really were saved along with the graph.
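Since restoring the graph later requires the exact node names (such as 'out' above), it can be handy to list what is actually stored in the frozen file. A minimal sketch, assuming the graph.pb written above:

import tensorflow as tf

graph_def = tf.GraphDef()
with open('./graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Print the name and op type of every node in the frozen graph
for node in graph_def.node:
    print(node.name, node.op)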

The problem is that our network still needs an interface for feeding in custom input data; otherwise, what use is it? No need to worry, there is of course a way.

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add(a + b, input_tensor, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ['out'])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

Re-save the network to graph.pb with the code above. This time we have an input placeholder, so let's look at how to restore the network and feed in custom data.

import tensorflow as tf

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'input:0': 4.}, return_elements=['out:0'], name='a')
        print(sess.run(output))

The output is:

[array([[ 11.],
       [ 12.]], dtype=float32)]

You can see that the result is correct. Of course, in input_map you can also map the input to a new custom placeholder, as follows:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'input:0': new_input}, return_elements=['out:0'], name='a')
        print(sess.run(output, feed_dict={new_input: 4.}))

The output again looks correct:

[array([[ 11.],
       [ 12.]], dtype=float32)]

Another point to note: if you pass as_text=True when writing the network architecture with tf.train.write_graph, then you need to make a small change when importing the network.

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
    # Do not open the file in 'rb' mode here
    with open('./graph.pb', 'r') as f:
        graph_def = tf.GraphDef()
        # Do not use graph_def.ParseFromString(f.read())
        text_format.Merge(f.read(), graph_def)
        output = tf.import_graph_def(graph_def, return_elements=['out:0'])
        print(sess.run(output))
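For completeness, the matching export side is simply the earlier freezing code with as_text=True; a rough sketch (writing to the same graph.pb name used in the import above):

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.Variable([[3], [4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add(a + b, input_tensor, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ['out'])
    # as_text=True writes a human-readable protobuf text file instead of a binary one
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=True)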
