TensorFlow: Train and Save a Simple Model in Python, Restore It in Java

Source: Internet
Author: User

This post summarizes some of what I learned while using TensorFlow recently. The scenario: train a simple linear regression (LR) model in Python, save it in the SavedModel format, then restore it in both Python and Java to predict results.

(1) Training the model

import tensorflow as tf
import numpy as np

# Generate the training data: 4 samples with 2 features each
x = np.array([[1, 1], [1, 2], [1, 3], [2, 4]], dtype=np.float32)
print('====================')
print(x)
print(x.shape)
print(x.dtype)

# Create a placeholder as the model input (any batch size, 2 features)
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])
# Target outputs for the 4 samples
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)
# Single-layer network: a Dense layer with 1 unit and no activation is an LR model
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)

# Build the loss function (mean squared error)
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
# Gradient descent optimizer with learning rate 0.01
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

# Create the session
sess = tf.Session()
# Writer that saves the graph for TensorBoard visualization
# (created after the graph is complete so the loss and optimizer ops are included)
writer = tf.summary.FileWriter("/users/yourname/pythonworkspace/tmp/log", sess.graph)
# Initialize all variables once the whole graph has been built
init = tf.global_variables_initializer()
sess.run(init)

# Train the model
print('================start=================')
for i in range(10000):
    _, loss_value = sess.run((train, loss), feed_dict={x_inputs: x})
    if i % 1000 == 0:
        print(loss_value)

# Close the writer; view the graph with: tensorboard --logdir /users/yourname/pythonworkspace/tmp/log
writer.close()

# Create the SavedModel builder (the export directory must not already exist)
builder = tf.saved_model.builder.SavedModelBuilder("/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel")
# x_inputs is the input tensor
inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}
# y_pred is the final output tensor
outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}
# The third argument is the signature's method name
signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')
# Save the model under the tag 'test_saved_model' with the signature key 'test_signature'
builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature': signature})
builder.save()
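To see what the graph above actually computes, here is a minimal pure-Python sketch of the same training loop with no TensorFlow involved. It assumes the model is exactly y = w1 * x1 + w2 * x2 + b (a Dense layer with one unit and no activation), that the loss is the mean squared error over the four samples (what tf.losses.mean_squared_error reduces to with its default weights), and that the parameters start at zero, whereas tf.layers.Dense initializes its kernel randomly. The function names predict, mse and train are illustrative only.

```python
# Pure-Python re-implementation of the training loop (no TensorFlow).
# Model: y = w1 * x1 + w2 * x2 + b, i.e. a Dense(units=1) layer with no activation.
# Assumption: parameters start at zero; tf.layers.Dense uses a random kernel init.

X = [[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [2.0, 4.0]]
Y = [2.0, 4.0, 5.0, 9.0]


def predict(w, b, row):
    """Linear model output for one input row."""
    return w[0] * row[0] + w[1] * row[1] + b


def mse(w, b):
    """Mean squared error over all training samples."""
    return sum((predict(w, b, r) - t) ** 2 for r, t in zip(X, Y)) / len(X)


def train(lr=0.01, steps=10000):
    """Plain gradient descent on the MSE, with the same learning rate and step count."""
    w, b = [0.0, 0.0], 0.0
    n = len(X)
    for i in range(steps):
        # Residuals of the current model on every sample.
        err = [predict(w, b, r) - t for r, t in zip(X, Y)]
        # Gradients of mean((pred - y)^2) with respect to each parameter.
        grad_w = [2.0 / n * sum(e * r[j] for e, r in zip(err, X)) for j in range(2)]
        grad_b = 2.0 / n * sum(err)
        # Gradient descent update.
        w = [w[j] - lr * grad_w[j] for j in range(2)]
        b -= lr * grad_b
        if i % 1000 == 0:
            print(i, mse(w, b))
    return w, b


w, b = train()
print("weights:", w, "bias:", b)
print("prediction for [1, 6]:", predict(w, b, [1.0, 6.0]))
```

For this data the least-squares optimum is w = (7/3, 1.5) and b = -5/3, so after 10,000 steps the prediction for [1, 6] should approach 29/3 (about 9.67) and the loss should level off near 1/24 (about 0.042). The TensorFlow model starts from a random kernel, so it should land close to these values but not bit-for-bit on them.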

(2) Loading the model in Python

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
  # Load the model using the same tag it was saved with
  meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], "/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel")
  # Read the model's signatures
  signature = meta_graph_def.signature_def
  print(signature)
  # Get the tensor names from the signature
  y_tensor_name = signature['test_signature'].outputs['output'].name
  x_tensor_name = signature['test_signature'].inputs['input'].name
  print(y_tensor_name)
  print(x_tensor_name)
  # Restore the tensors from the loaded graph
  y_pred = sess.graph.get_tensor_by_name(y_tensor_name)
  x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)

  # Predict the result for the input [1, 6]
  print(sess.run(y_pred, feed_dict={x_inputs: [[1, 6]]}))
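The names read from the signature are tensor names of the form op_name:output_index, for example something like dense/BiasAdd:0 for the Dense layer's output (the exact names depend on how the graph was built, so print them as the code does rather than hard-coding them). The hypothetical helper below is a small sketch, not part of any TensorFlow API; it shows how such a name splits into the operation name and the output index, with a bare name meaning output 0.

```python
def split_tensor_name(name):
    """Split a tensor name 'op_name:index' into (op_name, index).

    A bare operation name without ':' refers to output 0 of that operation.
    """
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No ':' present, so the whole string is the operation name.
        return name, 0
    return op_name, int(index)


print(split_tensor_name("dense/BiasAdd:0"))  # ('dense/BiasAdd', 0)
print(split_tensor_name("Placeholder"))      # ('Placeholder', 0)
```

As far as I know, the TensorFlow Java Runner's feed and fetch methods accept either form as well, which is why the Java code can pass the signature's names through unchanged.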

(3) Loading the model in Java
Add the TensorFlow dependencies to the Maven pom.xml:

<dependencies>
    <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.8.0-rc0</version>
    </dependency>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>proto</artifactId>
        <version>1.8.0-rc1</version>
    </dependency>
</dependencies>

Code to load the model:

import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

import java.util.List;

public class Test {
    public static void main(String[] args) throws InvalidProtocolBufferException {
        /* Load the model with the same tag used when saving it in Python */
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load("/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel", "test_saved_model")) {
            /* Build the input for prediction: one sample with features [1, 6] */
            float[][] matrix = new float[1][2];
            matrix[0][0] = 1;
            matrix[0][1] = 6;
            /* Get the model signature and the tensor names it maps to */
            SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow("test_signature");
            String inputName = sig.getInputsMap().get("input").getName();
            System.out.println(inputName);
            String outputName = sig.getOutputsMap().get("output").getName();
            System.out.println(outputName);
            /* Run the prediction; tensors hold native memory, so close them when done */
            try (Tensor<Float> x = Tensor.create(matrix, Float.class)) {
                List<Tensor<?>> y = savedModelBundle.session().runner().feed(inputName, x).fetch(outputName).run();
                float[][] result = new float[1][1];
                System.out.println(y.get(0).dataType());
                y.get(0).copyTo(result);
                System.out.println(result[0][0]);
                for (Tensor<?> t : y) {
                    t.close();
                }
            }
        }
    }
}
