잠토의 잠망경 (Jamto's Periscope)

[ML] Multi-output Parallel Series



잠수함토끼 2020. 2. 15. 21:29

GITHUB

https://github.com/yiwonjae/Project_Lotto/blob/master/Book_001/p107.py

0. Goal

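This approach produces one output per feature. The model looks at the last n_steps time steps of all parallel series and predicts the next value of every series. Each series gets its own output layer.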

1. DATA

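import numpy as np

# Two input series and an output series that is their element-wise sum
in_seq1 = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90])
in_seq2 = np.array([15, 25, 35, 45, 55, 65, 75, 85, 95])
out_seq = np.array([in_seq1[i]+in_seq2[i] for i in range(len(in_seq1))])

# Reshape each series into a column vector of shape (9, 1)
in_seq1 = in_seq1.reshape((len(in_seq1), 1))
in_seq2 = in_seq2.reshape((len(in_seq2), 1))
out_seq = out_seq.reshape((len(out_seq), 1))

# Stack the columns side by side: dataset has shape (9, 3)
dataset = np.hstack((in_seq1, in_seq2, out_seq))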

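The data is built the same way as before. Each row is one time step, and the three columns are in_seq1, in_seq2 and their sum.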
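As a side note, you can build the same (9, 3) array without the list comprehension and the three reshapes. The sketch below is an alternative, not the author's code, and assumes only NumPy. It uses element-wise addition plus np.column_stack, which treats each 1-D series as a column.

```python
import numpy as np

# Same two input series as above
in_seq1 = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90])
in_seq2 = np.array([15, 25, 35, 45, 55, 65, 75, 85, 95])

# Element-wise sum replaces the list comprehension
out_seq = in_seq1 + in_seq2

# column_stack turns each 1-D series into a column, so no reshape is needed
dataset = np.column_stack((in_seq1, in_seq2, out_seq))

print(dataset.shape)  # (9, 3)
print(dataset[0])     # [10 15 25]
```

Both versions produce the same integer array. This one simply avoids building intermediate column vectors by hand.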

2. DATA preparation


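# Multi-output Parallel Series
from typing import Tuple

import numpy as np
from numpy import ndarray


def split_sequence(sequence: ndarray, n_steps: int) -> Tuple[ndarray, ndarray]:
    """Split a (time, features) array into input windows and next-step targets."""
    x, y = [], []

    for i in range(len(sequence)):
        # Stop once the target row would fall past the last row
        if i + n_steps >= len(sequence):
            break

        x.append(sequence[i:i + n_steps, :])  # n_steps rows of all features
        y.append(sequence[i + n_steps, :])    # the following row, all features

    return np.asarray(x), np.asarray(y)


n_steps = 3

x, y = split_sequence(dataset, n_steps)  # x: (6, 3, 3), y: (6, 3)

n_feature = x.shape[2]

# One (n_samples, 1) target array per output head
y1 = y[:, 0].reshape((y.shape[0], 1))
y2 = y[:, 1].reshape((y.shape[0], 1))
y3 = y[:, 2].reshape((y.shape[0], 1))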

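Only the handling of y changes here. Each target row holds the next value of all three series. That row is split by column into y1, y2 and y3, so each output head gets its own target.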
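A quick way to check the shapes is a small self-contained sketch. It rebuilds the dataset, applies an equivalent windowing loop, and splits y column by column. The list `targets` is a hypothetical generalisation of y1/y2/y3 that works for any number of series; it is not part of the original code.

```python
import numpy as np


def split_sequence(sequence, n_steps):
    # Each sample: n_steps consecutive rows as input, the following row as target
    x, y = [], []
    for i in range(len(sequence) - n_steps):
        x.append(sequence[i:i + n_steps, :])
        y.append(sequence[i + n_steps, :])
    return np.asarray(x), np.asarray(y)


in_seq1 = np.arange(10, 100, 10)   # 10, 20, ..., 90
in_seq2 = in_seq1 + 5              # 15, 25, ..., 95
dataset = np.column_stack((in_seq1, in_seq2, in_seq1 + in_seq2))

x, y = split_sequence(dataset, 3)

# One (n_samples, 1) target array per column, for any number of columns
targets = [y[:, [j]] for j in range(y.shape[1])]

print(x.shape, y.shape)            # (6, 3, 3) (6, 3)
print([t.shape for t in targets])  # [(6, 1), (6, 1), (6, 1)]
print(y[0])                        # [40 45 85]
```

With 9 rows and n_steps = 3, there are 9 - 3 = 6 windows. Indexing with a list (`[j]`) keeps the column axis, so no extra reshape is needed.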


3. Training

from keras import Model
from keras.layers import Dense, Conv1D, MaxPool1D, Flatten, Input

# Shared CNN trunk
v = Input(shape=(n_steps, n_feature))
cnn = Conv1D(64, 2, activation='relu')(v)
cnn = MaxPool1D()(cnn)
cnn = Flatten()(cnn)
cnn = Dense(50, activation='relu')(cnn)

# One Dense(1) head per series
output1 = Dense(1)(cnn)
output2 = Dense(1)(cnn)
output3 = Dense(1)(cnn)

model = Model(inputs=v, outputs=[output1, output2, output3])
model.compile(optimizer='adam', loss='mse')

# Targets are passed as a list in the same order as the outputs
model.fit(x, [y1, y2, y3], epochs=2000, verbose=1)

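The input x is the same as before; only the outputs differ. The three Dense(1) layers share one CNN trunk, and fit takes a list of three target arrays, one per output.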
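To make the shared-trunk idea concrete, the sketch below replays the same architecture as a NumPy forward pass. The weights are random and hypothetical: `W_conv`, `W_dense` and `heads` are illustrative names, not Keras objects. The sketch does not train anything.

It rests on three assumptions:

- Keras's default Conv1D padding is 'valid'.
- MaxPool1D defaults to pool_size=2.
- To my understanding, compiling with a single loss='mse' applies MSE to each output and minimises their sum with equal weights.

```python
import numpy as np

rng = np.random.default_rng(0)
n_steps, n_feature, filters, kernel, hidden = 3, 3, 64, 2, 50

# Hypothetical random weights standing in for trained Keras weights
W_conv = rng.normal(size=(kernel, n_feature, filters))
b_conv = np.zeros(filters)
W_dense = rng.normal(size=(filters, hidden))
b_dense = np.zeros(hidden)
heads = [(rng.normal(size=(hidden, 1)), np.zeros(1)) for _ in range(3)]


def relu(a):
    return np.maximum(a, 0.0)


def forward(x):
    # x: (batch, n_steps, n_feature)
    # 'valid' 1-D convolution: output length n_steps - kernel + 1 = 2
    conv = np.stack(
        [np.einsum('bkf,kfo->bo', x[:, t:t + kernel, :], W_conv)
         for t in range(n_steps - kernel + 1)],
        axis=1,
    ) + b_conv
    conv = relu(conv)                           # (batch, 2, 64)
    # pool_size 2 over length 2 leaves length 1; flattening gives (batch, 64)
    pooled = conv.max(axis=1)
    shared = relu(pooled @ W_dense + b_dense)   # (batch, 50), shared by all heads
    return [shared @ W + b for W, b in heads]   # three (batch, 1) outputs


x = rng.normal(size=(6, n_steps, n_feature))
y_true = [rng.normal(size=(6, 1)) for _ in range(3)]
outputs = forward(x)

# One MSE per head; the training objective is assumed to be their plain sum
per_head = [np.mean((o - t) ** 2) for o, t in zip(outputs, y_true)]
total = sum(per_head)
print([o.shape for o in outputs])  # [(6, 1), (6, 1), (6, 1)]
```

All three heads read the same 50-unit shared vector, so gradients from all three losses update the trunk. Only the final Dense(1) weights belong to a single series.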

4. Prediction

x_input = np.array([[70,75,145], [80,85,165], [90,95,185]]) # expected: 100, 105, 205
x_input = x_input.reshape((1, n_steps, n_feature))
yhat = model.predict(x_input, verbose=0)
print(yhat) # [array([[101.457596]], dtype=float32), array([[107.120255]], dtype=float32), array([[208.32161]], dtype=float32)]