This post summarizes what I learned while working with TensorFlow recently. The scenario is simple: train a linear regression (LR) model in Python and save it in the SavedModel format. The saved model is then restored in both Python and Java to make predictions.
(1) Training the model
# Train a one-unit dense layer (linear regression) on four samples with the
# TensorFlow 1.x graph API, then export it in the SavedModel format.
import tensorflow as tf  # TensorFlow 1.x API (placeholders, sessions, SavedModelBuilder)
import numpy as np

# The export path must match the path used by the Python and Java loaders.
# SavedModelBuilder raises if this directory already exists, so delete it
# before re-running the script.
EXPORT_DIR = "/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel"
LOG_DIR = "/users/yourname/pythonworkspace/tmp/log"

# Generate training data: 4 samples with 2 features each.
x = np.array([[1, 1], [1, 2], [1, 3], [2, 4]], dtype=np.float32)
print('====================')
print(x)
print(x.shape)
print(x.dtype)

# Placeholder for the model input; the batch dimension is left open (None)
# so the exported model accepts any number of rows at prediction time.
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])
# Target values, one per training sample.
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)

# Single-layer network with one output unit: a linear regression model.
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)

# Mean-squared-error loss minimized with plain gradient descent.
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

with tf.Session() as sess:
    # Create the writer only after the whole graph (including loss and train
    # ops) is built, so TensorBoard shows the complete structure.
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

    # Initialize after every variable-creating op has been added to the graph.
    sess.run(tf.global_variables_initializer())

    print('================start=================')
    for i in range(10000):
        _, loss_value = sess.run((train, loss), feed_dict={x_inputs: x})
        if i % 1000 == 0:
            print(loss_value)

    # Inspect the graph with: tensorboard --logdir /users/yourname/pythonworkspace/tmp/log
    writer.close()

    builder = tf.saved_model.builder.SavedModelBuilder(EXPORT_DIR)
    # x_inputs is the model input tensor.
    inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}
    # y_pred is the output tensor callers want to fetch.
    outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}
    # 'test_sig_name' is a custom method name; the loaders below look tensors
    # up by signature key only. TF Serving would expect a standard method name.
    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs, 'test_sig_name')
    # Save the graph and trained variables under the tag 'test_saved_model'
    # and the signature key 'test_signature'; the loaders use the same names.
    builder.add_meta_graph_and_variables(
        sess, ['test_saved_model'], {'test_signature': signature})
    builder.save()
(2) Loading the model in Python
# Restore the SavedModel exported by the training script and predict one sample.
import tensorflow as tf  # TensorFlow 1.x API

# Must match the export path used when the model was saved.
EXPORT_DIR = "/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel"

# Load into a fresh graph so nothing from the default graph is mixed in.
with tf.Session(graph=tf.Graph()) as sess:
    # Load the meta graph saved under the tag 'test_saved_model'.
    meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], EXPORT_DIR)
    # Signature map: signature key -> SignatureDef.
    signature = meta_graph_def.signature_def
    print(signature)

    # Read the tensor names recorded in the signature.
    y_tensor_name = signature['test_signature'].outputs['output'].name
    x_tensor_name = signature['test_signature'].inputs['input'].name
    print(y_tensor_name)
    print(x_tensor_name)

    # Resolve the names back to tensors in the loaded graph.
    y_pred = sess.graph.get_tensor_by_name(y_tensor_name)
    x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)

    # Predict for a single sample with features [1, 6].
    print(sess.run(y_pred, feed_dict={x_inputs: [[1, 6]]}))
(3) Loading the model in Java
Add the TensorFlow dependencies (Maven):
<dependencies>
    <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.8.0</version>
    </dependency>
    <!-- Generated protobuf classes (MetaGraphDef, SignatureDef) used to read the
         model signature; keep this version identical to the tensorflow artifact. -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>proto</artifactId>
        <version>1.8.0</version>
    </dependency>
</dependencies>
Code to load the model and run a prediction:
Import com.google.protobuf.InvalidProtocolBufferException;
Import Org.tensorflow.SavedModelBundle;
Import Org.tensorflow.Tensor;
Import Org.tensorflow.framework.MetaGraphDef;
Import Org.tensorflow.framework.SignatureDef;
Import java.util.List;
public class Test {public static void main (string[] args) throws invalidprotocolbufferexception {*/* Load Model * * Savedmodelbundle Savedmodelbundle = Savedmodelbundle.load ("/users/yourname/pythonworkspace/tmp/savedmodel/lrmodel
"," Test_saved_model ");
/* Build Predictive tensor */float[][] matrix = new FLOAT[1][2];
Matrix[0][0] = 1;
MATRIX[0][1] = 6;
tensor<float> x = tensor.create (Matrix, float.class); /* Get model signature/Signaturedef sig = Metagraphdef.parsefrom (Savedmodelbundle.metagraphdef ()). Getsignaturedeforthrow ("Test
_signature ");
String InputName = Sig.getinputsmap (). Get ("input"). GetName ();
System.out.println (InputName); String outputname = Sig.getoutputsmap (). GET ("Output"). GetName ();
System.out.println (Outputname); /* Predictive Model Results */list<tensor<?>> y = savedmodelbundle.session (). Runner (). Feed (InputName, x). Fetch (Outputname
). Run ();
float [[] result = new FLOAT[1][1];
System.out.println (Y.get (0). DataType ());
System.out.println (Y.get (0). CopyTo (result));
System.out.println (Result[0][0]); }
}