순서가 있는 데이터를 위한 딥러닝 기본 - RNN BPTT

 

RNN의 학습, BPTT (Back Propagation Through Time)

 

오른쪽에 펼친거 보면, backpropogation 일어난 후, 기존 hidden state 가중치에 업데이트 됨

  • 먼저 loss 함수를 구체적으로 기술하자면,

softmax와 CrossEntropyLoss

를 사용하는 Classification 모델이라면,

  • 가중치 Wyh는 전시간에 공유되므로, Wyh를 기반으로 미분가능

  • 마찬가지로 Lt+1dms ht+1과 전시간에 공유되는 Whh를 기반으로 다음과 같이 표현할 수 있음

  • 이를 ht까지 확장하면, 다음과 같이 표현할 수 있음 (ht+1 계산에 ht가 사용되므로 다음과 같이 미분 가능)

앞 단계 hidden state 값은 그 전 단계 hidden state 값으로 미분 가능

-> 체인룰임

예를 들어,

이렇게 표현할 수 있음.

  • 따라서, Lt+1은 다음과 같이 표현할 수 있음

그 시점의 hidden state의 미분한 값이 예전시간부터 지금시간까지 다 곱해진다. -> loss를 기반으로 미분값(가중치) 계산할 때 곱셈으로 들어감.

  • 즉, 각 시점에서 발생한 hidden state까지의 미분값을 모두 합친 것
  • hidden state는 tanh 함수를 사용하고 tanh 함수의 미분은 

tanh 의 미분 값은 0과 1 사이임 (전 게시물 그래프 참고) -> 곱할수록 그 전 값보다 작아짐. (ex. 0.2 * 0.2 = 0.04)

이므로,

hidden state 미분값이 계속 곱해지면, 결국 gradient vanishing 문제 발생

  • 이로 인해, 긴 sequence 를 가진 데이터(time-step이 긴 데이터)는 가중치 업데이트가 거의 되지 않는 문제 발생
    • 이를 긴 sequence를 기억하지 못한다고 표현함

 

결론)

긴 sequence는 RNN이 해결하지 못함