TensorFlow 基本用法樣本

來源:互聯網
上載者:User

TensorFlow 基本用法樣本

本篇內容基於 Python3 TensorFlow 1.4 版本。本節內容 本節通過最簡單的樣本 —— 平面擬合來說明 TensorFlow 的基本用法。

構造資料 TensorFlow 的引入方式是:

import tensorflow as tf

接下來我們構造一些隨機的三維資料,然後用 TensorFlow 找到平面去擬合它,首先我們用 Numpy 產生隨機三維點,其中變數 x 代表三維點的 (x, y) 座標,是一個 2×100 的矩陣,即 100 個 (x, y),然後變數 y 代表三位點的 z 座標,我們用 Numpy 來產生這些隨機的點:

import numpy as np
x_data = np.float32(np.random.rand(2, 100))
y_data = np.dot([0.300, 0.200], x_data) + 0.400

print(x_data)
print(y_data)

這裡利用 Numpy 的 random 模組的 rand() 方法產生了 2×100 的隨機矩陣,這樣就產生了 100 個 (x, y) 座標,然後用了一個 dot() 方法算了矩陣乘法,用了一個長度為 2 的向量跟此矩陣相乘,得到一個長度為 100 的向量,然後再加上一個常量,得到 z 座標,輸出結果範例如下:

[[ 0.97232962  0.08897641  0.54844421  0.5877986  0.5121088  0.64716059
  0.22353953  0.18406206  0.16782761  0.97569454  0.65686035  0.75569868
  0.35698661  0.43332314  0.41185728  0.24801297  0.50098598  0.12025958
  0.40650111  0.51486945  0.19292323  0.03679928  0.56501174  0.5321334
  0.71044683  0.00318134  0.76611853  0.42602748  0.33002195  0.04414672
  0.73208278  0.62182301  0.49471655  0.8116194  0.86148429  0.48835048
  0.69902027  0.14901569  0.18737803  0.66826463  0.43462989  0.35768151
  0.79315376  0.0400687  0.76952982  0.12236254  0.61519378  0.92795062
  0.84952474  0.16663995  0.13729768  0.50603199  0.38752931  0.39529857
  0.29228279  0.09773371  0.43220878  0.2603009  0.14576958  0.21881725
  0.64888018  0.41048348  0.27641159  0.61700606  0.49728736  0.75936913
  0.04028837  0.88986284  0.84112513  0.34227493  0.69162005  0.89058989
  0.39744586  0.85080278  0.37685293  0.80529863  0.31220895  0.50500977
  0.95800418  0.43696108  0.04143282  0.05169986  0.33503434  0.1671818
  0.10234453  0.31241918  0.23630807  0.37890589  0.63020509  0.78184551
  0.87924582  0.99288088  0.30762389  0.43499199  0.53140771  0.43461791
  0.23833922  0.08681628  0.74615192  0.25835371]
 [ 0.8174957  0.26717573  0.23811154  0.02851068  0.9627012  0.36802396
  0.50543582  0.29964805  0.44869211  0.23191817  0.77344608  0.36636299
  0.56170034  0.37465382  0.00471885  0.19509546  0.49715847  0.15201907
  0.5642485  0.70218688  0.6031307  0.4705168  0.98698962  0.865367
  0.36558965  0.72073907  0.83386165  0.29963031  0.72276717  0.98171854
  0.30932376  0.52615297  0.35522953  0.13186514  0.73437029  0.03887378
  0.1208882  0.67004597  0.83422536  0.17487818  0.71460873  0.51926661
  0.55297899  0.78169805  0.77547258  0.92139858  0.25020468  0.70916855
  0.68722379  0.75378138  0.30182058  0.91982585  0.93160367  0.81539184
  0.87977934  0.07394848  0.1004181  0.48765802  0.73601437  0.59894943
  0.34601998  0.69065076  0.6768015  0.98533565  0.83803362  0.47194552
  0.84103006  0.84892255  0.04474261  0.02038293  0.50802571  0.15178065
  0.86116213  0.51097614  0.44155359  0.67713588  0.66439205  0.67885226
  0.4243969  0.35731083  0.07878648  0.53950399  0.84162414  0.24412845
  0.61285144  0.00316137  0.67407191  0.83218956  0.94473189  0.09813353
  0.16728765  0.95433819  0.1416636  0.4220584  0.35413414  0.55999744
  0.94829601  0.62568033  0.89808714  0.07021013]]
[ 0.85519803  0.48012807  0.61215557  0.58204171  0.74617288  0.66775297
  0.56814902  0.51514823  0.5400867  0.739092    0.75174732  0.6999822
  0.61943605  0.60492771  0.52450095  0.51342299  0.64972749  0.46648169
  0.63480003  0.69489821  0.57850311  0.50514314  0.76690145  0.73271342
  0.68625198  0.54510222  0.79660789  0.58773431  0.64356002  0.60958773
  0.68148959  0.6917775  0.61946087  0.66985885  0.80531934  0.5542799
  0.63388372  0.5787139  0.62305848  0.63545502  0.67331071  0.61115777
  0.74854193  0.56836022  0.78595346  0.62098848  0.63459907  0.8202189
  0.79230218  0.60074826  0.50155342  0.73577477  0.70257953  0.68166794
  0.6636407  0.44410981  0.54974625  0.57562188  0.59093375  0.58543506
  0.66386805  0.6612752  0.61828378  0.78216895  0.71679293  0.72219985
  0.58029252  0.83674336  0.66128606  0.50675907  0.70909116  0.6975331
  0.69146618  0.75743606  0.6013666  0.77701676  0.6265411  0.68727338
  0.77228063  0.60255049  0.42818714  0.52341076  0.66883513  0.49898023
  0.55327365  0.49435803  0.6057068  0.68010968  0.77800791  0.65418036
  0.69723127  0.8887319  0.52061989  0.61490928  0.63024914  0.64238486
  0.66116097  0.55118095  0.80346301  0.49154814]

這樣我們就得到了一些三維的點。

構造模型 隨後我們用 TensorFlow 來根據這些資料擬合一個平面,擬合的過程實際上就是尋找 (x, y) 和 z 的關係,即變數 x_data 和變數 y_data 的關係,而它們之間的關係剛才我們用了線性變換表示出來了,即 z = w * (x, y) + b,所以擬合的過程實際上就是找 w 和 b 的過程,所以這裡我們就首先像設變數一樣來設兩個變數 w 和 b,代碼如下:

x = tf.placeholder(tf.float32, [2, 100])
y_label = tf.placeholder(tf.float32, [100])
b = tf.Variable(tf.zeros([1]))
w = tf.Variable(tf.random_uniform([2], -1.0, 1.0))
y = tf.matmul(tf.reshape(w, [1, 2]), x) + b

在建立模型的時候,我們首先可以將現有的變數來表示出來,用 placeholder() 方法聲明即可,一會我們在啟動並執行時候傳遞給它真實的資料就好,第一個參數是資料類型,第二個參數是形狀,因為 x_data 是 2×100 的矩陣,所以這裡形狀定義為 [2, 100],而 y_data 是長度為 100 的向量,所以這裡形狀定義為 [100],當然此處使用元組定義也可以,不過要寫成 (100, )。

隨後我們用 Variable 初始化了 TensorFlow 中的變數,b 初始化為一個常量,w 是一個隨機初始化的 1×2 的向量,範圍在 -1 和 1 之間,然後 y 再用 w、x、b 表示出來,其中 matmul() 方法就是 TensorFlow 中提供的矩陣乘法,類似 Numpy 的 dot() 方法。不過不同的是 matmul() 不支援向量和矩陣相乘,即不能 BroadCast,所以在這裡做乘法前需要先調用 reshape() 一下轉成 1×2 的標準矩陣,最後將結果表示為 y。

這樣我們就構造出來了一個線性模型。

這裡的 y 是我們模型中輸出的值,而真實的資料卻是我們輸入的 y_data,即 y_label。

損失函數 要擬合這個平面的話,我們需要減小 y_label 和 y 的差距就好了,這個差距越小越好。

所以接下來我們可以定義一個損失函數,來代表模型實際輸出值和真實值之間的差距,我們的目的就是來減小這個損失,代碼實現如下:

loss = tf.reduce_mean(tf.square(y - y_label))

這裡調用了 square() 方法,傳入 y_label 和 y 的差來求得平方和,然後使用 reduce_mean() 方法得到這個值的平均值,這就是現在模型的損失值,我們的目的就是減小這個損失值,所以接下來我們使用梯度下降的方法來減小這個損失值即可,定義如下代碼:

optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

這裡定義了 GradientDescentOptimizer 最佳化,即使用梯度下降的方法來減小這個損失值,我們訓練模型就是來類比這個過程。

運行模型 最後我們將模型運行起來即可,運行時必須聲明一個 Session 對象,然後初始化所有的變數,然後執行一步步的訓練即可,實現如下:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(201):
        sess.run(train, feed_dict={x: x_data, y: y_data})
        if step % 10 == 0:
            print(step, sess.run(w), sess.run(b))

這裡定義了 200 次迴圈,每一次迴圈都會執行一次梯度下降最佳化,每次迴圈都調用一次 run() 方法,傳入的變數就是剛才定義個 train 對象,feed_dict 就把 placeholder 類型的變數賦值即可。隨著訓練的進行,損失會越來越小,w 和 b 也會被慢慢調整為擬合的值。

在這裡每 10 次 迴圈我們都列印輸出一下擬合的 w 和 b 的值,結果如下:

0 [ 0.31494665  0.33602586] [ 0.84270978]
10 [ 0.19601417  0.17301694] [ 0.47917289]
20 [ 0.23550016  0.18053198] [ 0.44838765]
30 [ 0.26029009  0.18700737] [ 0.43032286]
40 [ 0.27547371  0.19152154] [ 0.41897511]
50 [ 0.28481475  0.19454622] [ 0.41185945]
60 [ 0.29058149  0.19652548] [ 0.40740564]
70 [ 0.2941508  0.19780098] [ 0.40462157]
80 [ 0.29636407  0.1986146 ] [ 0.40288284]
90 [ 0.29773837  0.19913  ] [ 0.40179768]
100 [ 0.29859257  0.19945487] [ 0.40112072]
110 [ 0.29912385  0.199659  ] [ 0.40069857]
120 [ 0.29945445  0.19978693] [ 0.40043539]
130 [ 0.29966027  0.19986697] [ 0.40027133]
140 [ 0.29978839  0.19991697] [ 0.40016907]
150 [ 0.29986817  0.19994824] [ 0.40010536]
160 [ 0.29991791  0.1999677 ] [ 0.40006563]
170 [ 0.29994887  0.19997987] [ 0.40004089]
180 [ 0.29996812  0.19998746] [ 0.40002549]
190 [ 0.29998016  0.19999218] [ 0.40001586]
200 [ 0.29998764  0.19999513] [ 0.40000987]

可以看到���隨著訓練的進行,w 和 b 也慢慢接近真實的值,擬合越來越精確,接近正確的值。

本文永久更新連結地址:https://www.bkjia.com/Linux/2018-03/151273.htm

相關文章

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.