1.我先上源碼
下面的代碼是一個人寫的lstm輸入資料處理的:
def load_data(filename, seq_len, normalise_window): f = open(filename, 'rb').read() data = f.split('\n') sequence_length = seq_len + 1 result = [] for index in range(len(data) - sequence_length): result.append(data[index: index + sequence_length]) if normalise_window: result = normalise_windows(result) result = np.array(result) row = round(0.9 * result.shape[0]) train = result[:row, :] np.random.shuffle(train) x_train = train[:, :-1] y_train = train[:, -1] x_test = result[row:, :-1] y_test = result[row:, -1] x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1)) x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1)) return [x_train, y_train, x_test, y_test]
沒前進一步,取一個步長的資料
比如,來源資料是
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
然後我約定的步長是3,那麼取得的資料是[1,2,3], [2,3,4], ..., [13,14,15]
2.代碼解釋
x_train = train[:, :-1]
這個代碼錶示取出train中的每一行,但是對於列,最後的一列不取
GitHub源碼:LSTM_learn