2017. 5. 12. 11:44ㆍmachine learning
어떤 특징에 대한 답을 학습시키는 것이 아니라 특징의 문맥도 학습시키자라는 개념에서 나온 것이 RNN이라고 할 수 있다. 예를 들어 노래를 부를 때 처음부터 부르면 노래를 부르기가 쉽지만 갑자기 중간부터 부르라고 하면 어떤 가사인지 생각나지 않는 경우가 있다. 이런 식으로 sequence 자체를 기억시키자는 개념에서 나온 것이 바로 RNN이다.
간단한 encoder - decoder diagram을 살펴보자.
X(t)는 t시간에 들어온 입력 값, Win은 인코딩하기 위한 vector, Z(t)는 인코딩된 결과, Wout은 디코딩 하기 위한 vector, Y(t)는 t시간에서의 디코딩 결과 값이다.
이번에는 여기다가 시간 정보를 추가해 보자.
t시간 뿐만 아니라 ...,t-2,t-1,....,t+1,t+2... 시간 정보도 추가가 되었으며 W값이 추가가 되었다. W 값을 통해 우리는 이전 데이터로부터 영향력을 끼치거나 받을수가 있다. 위의 그림이 전형적인 RNN의 구조를 나타낸다.
이제 실제로 RNN을 이용해서 시계열 데이터를 예측해 보자. 데이터는 국제선 승객의 추이 분포를 가져와서 예측할 것이다. 아래의 사이트로 접근하자.
사이트에 접근했으면, 왼쪽의 Export 탭을 클릭한 후 CSV(,) 메뉴를 클릭해서 csv로 export를 하자. 그러면 csv file로 download가 된다. 텍스트 편집기를 열어 csv 파일을 연 후 맨 첫 라인과 맨 마지막 라인을 지워주자. (description 부분이며 실제 data 수치가 아니라 제거해야 함)
학습할 데이터가 준비가 되었다면 이제 아래의 코드를 살펴보자.
load_series 함수를 통해 csv 파일을 읽어서 데이터를 뽑은 후, normalization 까지 시켜주자. 그 후에 split_data 함수를 통해 train과 test data로 나누자.
이제 SeriesPredictor class를 살펴보자. 이 class를 통해 우리는 data를 학습하고 학습된 모델을 이용해 추론할 것이다.
class 초기값으로 입력 값의 차원 수, sequence 크기, hidden unit의 차원 수를 파라미터로 받고 있다.
class 내의 model 함수를 살펴보자. tensorflow에서 제공하는 LSTMCell을 사용하였고, dynamic_rnn 함수를 통해 rnn을 실행시킨다. 결과 값으로 나타나는 output과 states는 각각 RNN 입력에 대한 output 값들과 마지막 hidden 상태 값을 나타낸다. 아래 그림을 보면 이해가 쉽다.
위 코드에서는 각각 100차원을 가진 5개의 output들이 출력되고, 하나의 state 값이 출력될 것이다. 5 x 100 차원의 값들을 하나로 합치기 위해 100 x 1의 weight를 행렬 곱 한 후 bias를 더해 주자. 이렇게 나온 값과 실제 label의 값 들을 비교해서 cost 만큼 값을 수정해 나가면 된다.
여기서 input에 대한 정답은 input +1 에 있는 값들이 된다. 예를 들어 [1, 2, 3, 4, 5, 6, 7]이라는 순차 데이터가 있고 input이 [1,2,3,4,5] 라면 정답은 [2,3,4,5,6] 이 되는 셈이다. 즉 [1,2,3,4,5] 라는 입력이 들어 왔을 경우에 6이라는 값을 찾아 낼 수가 있다.
그럼 이제 실제로 모델을 돌려보자.
아까 위에서 언급했던 방식으로 가져온 데이터를 불러온 후 (load_series), train과 test set을 나누자. 그런 후 train set과 test set을 파라미터로 넘겨서 모델을 학습 시키자. 여기서 test set은 validation set 역할을 하고 있음을 알 수가 있다. step 100번 마다 test set을 돌려서 test_err가 떨어지고 있다면 min_test_err를 갱신하고 떨어지고 있지 않다면 그 시점에서 training을 중단시킨다. (여기서는 patience 변수를 둬서 해당 개수가 0이 될때까지 기다리고 있다.)
모델 훈련이 끝났다면, predictor.test 함수를 통해 실제 test를 진행한다. 실제로는 test set을 따로 만들며, validation set을 test set에서 이용하지는 않는다.
아래 그림과 같이 prediction 된 결과를 확인해 볼 수 있다.