LSTM相關的Python代碼__Python

來源:互聯網
上載者:User

1.我先上源碼

下面的代碼是一個人寫的lstm輸入資料處理的:

def load_data(filename, seq_len, normalise_window):    f = open(filename, 'rb').read()    data = f.split('\n')    sequence_length = seq_len + 1    result = []    for index in range(len(data) - sequence_length):        result.append(data[index: index + sequence_length])        if normalise_window:        result = normalise_windows(result)    result = np.array(result)    row = round(0.9 * result.shape[0])    train = result[:row, :]    np.random.shuffle(train)    x_train = train[:, :-1]    y_train = train[:, -1]    x_test = result[row:, :-1]    y_test = result[row:, -1]    x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))    x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))      return [x_train, y_train, x_test, y_test]

沒前進一步,取一個步長的資料

比如,來源資料是

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
然後我約定的步長是3,那麼取得的資料是[1,2,3], [2,3,4], ..., [13,14,15]

2.代碼解釋

x_train = train[:, :-1]


這個代碼錶示取出train中的每一行,但是對於列,最後的一列不取

GitHub源碼:LSTM_learn

相關文章

聯繫我們

該頁面正文內容均來源於網絡整理,並不代表阿里雲官方的觀點,該頁面所提到的產品和服務也與阿里云無關,如果該頁面內容對您造成了困擾,歡迎寫郵件給我們,收到郵件我們將在5個工作日內處理。

如果您發現本社區中有涉嫌抄襲的內容,歡迎發送郵件至: info-contact@alibabacloud.com 進行舉報並提供相關證據,工作人員會在 5 個工作天內聯絡您,一經查實,本站將立刻刪除涉嫌侵權內容。

A Free Trial That Lets You Build Big!

Start building with 50+ products and up to 12 months usage for Elastic Compute Service

  • Sales Support

    1 on 1 presale consultation

  • After-Sales Support

    24/7 Technical Support 6 Free Tickets per Quarter Faster Response

  • Alibaba Cloud offers highly flexible support services tailored to meet your exact needs.