tensorflow:一個簡單的python訓練儲存模型,java還原模型方法

來源:互聯網
上載者:User

總結一下這段時間學習使用tensorflow的一些經驗。主要應用情境是,使用python語言訓練一個簡單的LR模型,並且將模型以savedModel格式儲存模型,然後以python和java語言還原模型,預測結果。

(1)訓練模型

import tensorflow as tfimport numpy as np#產生訓練資料x = np.ndarray(dtype=np.float32, shape=[4, 2])x[0] = [1,1]x[1] = [1,2]x[2] = [1,3]x[3] = [2,4]print('====================')print(x)print(x.shape)print(x.dtype)#建立placeHolder作為輸入x_inputs = tf.placeholder(tf.float32, shape=[None, 2])#輸出結果y_true = tf.constant([[2], [4], [5], [9]], dtype=tf.float32)#單層神經網路,搭建LR模型linear_model = tf.layers.Dense(units=1)y_pred = linear_model(x_inputs)#構建sessionsess = tf.Session()#儲存模型tensorbord可視化結構的writerwriter = tf.summary.FileWriter("/Users/yourName/pythonworkspace/tmp/log", sess.graph)#初始設定變數init = tf.global_variables_initializer()sess.run(init)#構建損失函數loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)#梯度下降最佳化器optimizer = tf.train.GradientDescentOptimizer(0.01)train = optimizer.minimize(loss)#開始訓練模型print('================start=================')for i in range(10000):    _, loss_value = sess.run((train, loss), feed_dict={x_inputs:x})    if i % 1000 == 0:        print(loss_value)#關閉可視化writer,可以通過tensorboard --logdir /Users/yourName/pythonworkspace/tmp/log載入可視化模型writer.close()#構建savedModel構建器builder = tf.saved_model.builder.SavedModelBuilder("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")# x 為輸入tensor, keep_prob為dropout的prob tensorinputs = {'input': tf.saved_model.utils.build_tensor_info(x_inputs)}# y 為最終需要的輸出結果tensoroutputs = {'output': tf.saved_model.utils.build_tensor_info(y_pred)}signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')#儲存模型builder.add_meta_graph_and_variables(sess, ['test_saved_model'], {'test_signature':signature})builder.save()

(2)python 載入模型

import tensorflow as tfwith tf.Session(graph=tf.Graph()) as sess:  #載入模型  meta_graph_def = tf.saved_model.loader.load(sess, ['test_saved_model'], "/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel")  #載入模型簽名  signature = meta_graph_def.signature_def  print(signature)  #從簽名中獲得張量名  y_tensor_name = signature['test_signature'].outputs['output'].name  x_tensor_name = signature['test_signature'].inputs['input'].name  print(y_tensor_name)  print(x_tensor_name)  #還原張量  y_pred = sess.graph.get_tensor_by_name(y_tensor_name)  x_inputs = sess.graph.get_tensor_by_name(x_tensor_name)  # 預測結果  print(sess.run(y_pred, feed_dict={x_inputs:[[1,6]]}))

(3)java 載入模型
載入tensorflow依賴包

 <dependencies>        <!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow -->        <dependency>            <groupId>org.tensorflow</groupId>            <artifactId>tensorflow</artifactId>            <version>1.8.0-rc0</version>        </dependency>        <dependency>            <groupId>org.tensorflow</groupId>            <artifactId>proto</artifactId>            <version>1.8.0-rc1</version>        </dependency>    </dependencies>

載入模型代碼

import com.google.protobuf.InvalidProtocolBufferException;import org.tensorflow.SavedModelBundle;import org.tensorflow.Tensor;import org.tensorflow.framework.MetaGraphDef;import org.tensorflow.framework.SignatureDef;import java.util.List;public class Test {    public static void main(String[] args) throws InvalidProtocolBufferException {        /*載入模型 */        SavedModelBundle savedModelBundle = SavedModelBundle.load("/Users/yourName/pythonworkspace/tmp/savedModel/lrmodel", "test_saved_model");        /*構建預測張量*/        float[][] matrix = new float[1][2];        matrix[0][0] = 1;        matrix[0][1] = 6;        Tensor<Float> x = Tensor.create(matrix, Float.class);        /*擷取模型簽名*/        SignatureDef sig = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow("test_signature");        String inputName = sig.getInputsMap().get("input").getName();        System.out.println(inputName);        String outputName = sig.getOutputsMap().get("output").getName();        System.out.println(outputName);        /*預測模型結果*/        List<Tensor<?>> y = savedModelBundle.session().runner().feed(inputName, x).fetch(outputName).run();        float [][] result = new float[1][1];        System.out.println(y.get(0).dataType());        System.out.println(y.get(0).copyTo(result));        System.out.println(result[0][0]);    }}

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.