總結一下這段時間學習使用TensorFlow的一些經驗。主要應用情境是:使用Python語言訓練一個簡單的LR模型,並且將模型以SavedModel格式儲存,然後以Python和Java語言還原模型並預測結果。
(1)訓練模型
"""Train a simple linear-regression (LR) model with TensorFlow 1.x and
export it in SavedModel format so it can be restored later from Python
or Java."""
import tensorflow as tf
import numpy as np

# Output locations, hoisted so each path is defined exactly once.
LOG_DIR = "/Users/yourName/pythonworkspace/tmp/log"
EXPORT_DIR = "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel"

# Training data: 4 samples with 2 features each.  An explicit np.array
# literal replaces the original uninitialized np.ndarray that was filled
# row by row.
x = np.array([[1, 1],
              [1, 2],
              [1, 3],
              [2, 4]], dtype=np.float32)
print('====================')
print(x)
print(x.shape)
print(x.dtype)

# Placeholder used as the model input; batch dimension left open.
x_inputs = tf.placeholder(tf.float32, shape=[None, 2])
# Ground-truth targets for the 4 training samples.
y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)

# A single dense unit is the LR model: y = w . x + b.
linear_model = tf.layers.Dense(units=1)
y_pred = linear_model(x_inputs)

sess = tf.Session()

# Graph writer for TensorBoard; view later with:
#   tensorboard --logdir /Users/yourName/pythonworkspace/tmp/log
writer = tf.summary.FileWriter(LOG_DIR, sess.graph)

# Variables must be initialized before training.
init = tf.global_variables_initializer()
sess.run(init)

# Mean-squared-error loss, minimized with plain gradient descent.
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

# Training loop: print the loss every 1000 steps to watch convergence.
print('================start=================')
for i in range(10000):
    _, loss_value = sess.run((train, loss), feed_dict={x_inputs: x})
    if i % 1000 == 0:
        print(loss_value)

writer.close()

# Export as SavedModel: record the input/output tensors in a named
# signature so any language binding can look them up by name.
builder = tf.saved_model.builder.SavedModelBuilder(EXPORT_DIR)
inputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}
outputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs, outputs, 'test_sig_name')
builder.add_meta_graph_and_variables(
    sess, ['test_saved_model'], {'test_signature': signature})
builder.save()
(2)python 載入模型
# Restore the exported SavedModel in Python and run one prediction.
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # Load the model into a fresh graph by its export tag.
    meta_graph_def = tf.saved_model.loader.load(
        sess, ['test_saved_model'],
        "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")

    # The signature map recorded at export time.
    signature = meta_graph_def.signature_def
    print(signature)

    # Recover the concrete tensor names stored under our signature key.
    y_tensor_name = signature['test_signature'].outputs['output'].name
    x_tensor_name = signature['test_signature'].inputs['input'].name
    print(y_tensor_name)
    print(x_tensor_name)

    # Re-bind the input/output tensors from the restored graph.
    y_pred = sess.graph.get_tensor_by_name(y_tensor_name)
    x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)

    # Predict for a single sample with features [1, 6].
    print(sess.run(y_pred, feed_dict={x_inputs: [[1, 6]]}))
(3)java 載入模型
載入tensorflow依賴包
<dependencies>
    <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <!-- Keep both TensorFlow artifacts on the SAME version: the
             original mixed 1.8.0-rc0 with 1.8.0-rc1, which risks
             runtime incompatibilities between the binding and its
             generated protobuf classes. -->
        <version>1.8.0</version>
    </dependency>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>proto</artifactId>
        <version>1.8.0</version>
    </dependency>
</dependencies>
載入模型代碼
import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;

import java.util.List;

/**
 * Loads the SavedModel exported by the Python training script, resolves
 * the input/output tensor names from its signature, and runs a single
 * prediction for the sample [1, 6].
 */
public class Test {

    public static void main(String[] args) throws InvalidProtocolBufferException {
        /* Load the bundle by tag.  try-with-resources closes the native
         * graph/session handles; the original leaked the bundle and all
         * tensors (they hold off-heap memory). */
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load(
                "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel",
                "test_saved_model")) {

            /* One input sample: features [1, 6]. */
            float[][] matrix = new float[1][2];
            matrix[0][0] = 1;
            matrix[0][1] = 6;

            /* Look up the tensor names recorded in the model signature. */
            SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef())
                    .getSignatureDefOrThrow("test_signature");
            String inputName = sig.getInputsMap().get("input").getName();
            System.out.println(inputName);
            String outputName = sig.getOutputsMap().get("output").getName();
            System.out.println(outputName);

            /* Run the prediction; both the input and output tensors are
             * closed to release their native buffers. */
            try (Tensor<Float> x = Tensor.create(matrix, Float.class)) {
                List<Tensor<?>> y = savedModelBundle.session().runner()
                        .feed(inputName, x).fetch(outputName).run();
                try (Tensor<?> out = y.get(0)) {
                    float[][] result = new float[1][1];
                    System.out.println(out.dataType());
                    System.out.println(out.copyTo(result));
                    System.out.println(result[0][0]);
                }
            }
        }
    }
}