미분의 Chain Rule, Computational Graph Backward 그리고 Gradient Descent Optimizaton

자세한 Pytorch 기술관련 내용은 여기를 참조 바랍니다: https://github.com/fantajeon/DLPytorch1.2/blob/master/Chapter1.ipynb

Computational Graph란?

계산의 순서를 기록하기 위해서 비순환 그래프(DAG, directed acyclic graph) 사용한다. 이 그래프를 computational graph라고 한다. 다른말로는 입력 변수로부터 출력 변수까지 복잡한 계산 과정을 기록한 그래프 자료구조이다. 보통 이 그래프는 2개의 과정으로 구성된다:

  1. Forward 과정(생성,추론/예측): 
    1. 생성과정: 입력에서 출력까지 계산 과정을 모두 기억을 한다. 연산(+,-,ReLU등)이 호출될때마다 즉시 그래프가 만들어 진다. 
    2. 예측과정: 입력값의 출력값을 계산할때 사용한다.
  2. Backward 과정(미분): 편미분 계산을 위해서 기억한 모든 과정을 역순서로 전파한다. 특히 오차(Loss)에 근거하여 미분값(=chain rule)을 계산한다. 이건 error backpropagation algorithm 의 일부이기도 하다. 이 후에 backward 과정을 통해서 자동으로 계산된 미분값과 learning rate(수식 $\ref{gdopt}$의 $\eta$)와 함께 병합하여 모델 학습에 사용한다.

이렇게 입력부터 출력까지의 모든 연산 과정을 기록하고, computational graph를 통하여 chain rule에기반한 미분 값을 계산을 능동적으로 가능해진다. 즉, 모든 계산 과정을 코드에 기록할 필요가 없다. 당연한듯하지만, 예전에는 하나하나 미분된 방정식 계산하고, 직접 소스 코드에 한땀한땀 하드 코딩해서 모델을 학습했다. 그래서 미리 정해진 계산만 할 수 있었다. 모델을 만들고 학습 코드를 만들었다. 하지만 computation graph를 활용하여 순서가 뒤집어 진 것이다. 이 computational graph 덕분에 범용적인 학습 라이브러리가 먼저 배포되고, 모델이 나중에 생성하는 순서가 가능해졌다.

Computation Graph의 Backward 과정

Backward 과정은 미분의 chain rule을 구현한 과정이다. 입력(모델 파라미터)의 변화가 출력의 어떠한 영향을 주는지 computation graph의 미분 방정식을 계산할 수 있다. 미분의 의미는 출력이 원하지 않는 방향으로 높아 졌다면(오류가 증가했다면), 관련 입력(파라미터)을 반대 방향으로 감소하면 될 것입니다. "모델 학습"은 이 과정을 반복적으로 방향을 줄이거나 / 높이거나 한다는 것이다. 이 반복하는 과정을 수학적으로 표현하면 다음의 수식 $\ref{gdopt}$ 처럼된다. 수식 $\ref{gdopt}$는 이미 알고있는 gradient descent optimization 기법이다. 아랫 첨자 수열($t,t+1$)로 표현해서 반복과정을 표현을 많이 한다.

$$\begin{align} w_{t+1} = w_{t} - \eta \frac{\partial L}{\partial w}  \nonumber \\ b_{t+1} = b_{t} - \eta \frac{\partial L}{\partial b} \label{gdopt} \end{align}  $$

[수식 $\ref{gdopt}$] Gradient Descent Optimizatoin 과정


미방의 Chain Rule을 품은 Computational Graph기반 Backward

[그림 1] Computational Graph와 Backward의 도식화

핵심 Insight! $L$의 변화를 $w$까지 연결: $L=e^2=(y - f(g(h(......z(w)))))^2$

흔히 언급되는 손실함수(loss function)인 $L=(y-\hat {y})^2$에서 $w$를 학습하기 위해서는 $\frac{\partial L}{\partial w} = \textup{blackbox}$을 반듯이 계산해야 한다. 하지만, 현재까지의 지면상으로는 구체적으로 어떤 모양인기 언급한 적이 없어서 $w$와 $L$은 직접적인 연관을 지을 수 없다. 지금부터 연결 짓는 과정을 설명한다. 입력부터 여러 단계의 계산을 거친 최종 결과는(오류의 크기를 나타내는) 변수 $L$이다. 그럼 두 변수를 어떻게 연관을 지을까? 바로 답은 편미분방정식의 chain rule에 있다. Chain Rule은 각 단계별로 변화량을 누적시킨 과정이라고 볼 수 있다. 

$w$부터 여러단계를 거쳐서 $L$까지 이루어 진 것에 chain rule을 적용하면, $$\frac{\partial L}{\partial w} =  \frac{\partial z}{\partial w} \frac{\partial \hat{y}}{\partial z}   \frac{\partial e}{\partial \hat{y}} \frac{\partial L}{\partial e} \label{df_L}$$ 이렇게 계산된다. 위에 그림 1처럼 붉은색 원부터 보라색 점선을 따라서 초록색 원(모델 파라미터)까지 각 노드의 미분값을 곱하면 Chain Rule과 동일하다(=backward). 파란색 원은 고정된 값(훈련 데이터셋)이므로 미분값을 계산하지 않는다. 이 처럼 computational graph의 기록된 정보를 바탕으로 단순한 곱하기 과정을 누적하면 chain rule이 완성된다. 

이 미분식 $\ref{df_L}$는 입력의 변화량과 출력의 변화량의 관계를 설명하는 공식이다. 다시말하면, $L$값의 변화량을 보고 $w$와 $b$의 이동 방향을 결정할 수 있다. 이 원리를 활용하여 손실이 적은 최적의 모델 파라미터($w$와 $b$)를 찾을 수 있다.


$w$의 Chain Rule속의 $\frac{\partial Y}{\partial X}$꼴 구체화 실습 해보기

그림 1은 아래의 방정식의 computational graph이다:
$$z = xw \nonumber$$
$$\hat {y} = z + b \nonumber$$
$$e = y - \hat{y} \nonumber$$
$$L = e^2 \nonumber$$
그리고 이것을 미분해보겠습니다.
$$\frac{\partial z}{\partial w} =x \nonumber$$
$$\frac{\partial \hat{y}}{\partial z}=1 \nonumber$$
$$\frac{\partial e}{\partial \hat{y}} = -1 \nonumber$$
$$\frac{\partial L}{\partial e} = 2e \nonumber$$
$$\frac{\partial L}{\partial L} = 1 \nonumber$$
자 이제 Chain Rule로 계산된 값을 각각 교체해 본다. 그러면 다음과 같이 계산된다:
$$ \begin{align} \frac{\partial L}{\partial w} &= \frac{\partial z}{\partial w} \frac{\partial \hat{y}}{\partial z}   \frac{\partial e}{\partial \hat{y}} \frac{\partial L}{\partial e} \\ &= (x)(1)(-1)(2e)(1)=-x(y-\hat{y}) \label{error} \end{align}$$

이 수식 $\ref{error}$에서 $e$값은 마지막 예측한 값($\hat{y}$)과 참 값($y$)의 오류로 볼 수가 있다. 학습하려는 모델인 $wx + b$에서 이 오류($e$)가 입력인 $x$를 곱하여 만들어진다(수식 $\ref{gdopt}$의 첫번째 식, $w$ 학습이다). 이 수식의 의미는 강한 입력은 강한 오차를 학습에 반영한다. 

결론

Computational graph에서 오른쪽 맨 끝($L$)에 오류가 왼쪽의 입력단에 $w$, $b$까지 전파된 것을 볼 수 있다. 즉, backward과정을 자세히 살펴보았다.

댓글 없음:

댓글 쓰기

[Rust] Ownership, Scope, Transfer Ownership 과 Borrowing

Scope, Ownership, Transfer Rust에서 사용하는 영역, 소유, 소유권 이전, 복사, 빌림을 요약해 보는 것이 목적이다. 소유(Ownership)은 Java, Go와 같은 백그라운드에서 실행하는 garbage collector가 없...