미분의 Chain Rule, Computational Graph Backward 그리고 Gradient Descent Optimizaton

자세한 Pytorch 기술관련 내용은 여기를 참조 바랍니다: https://github.com/fantajeon/DLPytorch1.2/blob/master/Chapter1.ipynb

Computational Graph란?

계산의 순서를 기록하기 위해서 비순환 그래프(DAG, directed acyclic graph) 사용한다. 이 그래프를 computational graph라고 한다. 다른말로는 입력 변수로부터 출력 변수까지 복잡한 계산 과정을 기록한 그래프 자료구조이다. 보통 이 그래프는 2개의 과정으로 구성된다:

  1. Forward 과정(생성,추론/예측): 
    1. 생성과정: 입력에서 출력까지 계산 과정을 모두 기억을 한다. 연산(+,-,ReLU등)이 호출될때마다 즉시 그래프가 만들어 진다. 
    2. 예측과정: 입력값의 출력값을 계산할때 사용한다.
  2. Backward 과정(미분): 편미분 계산을 위해서 기억한 모든 과정을 역순서로 전파한다. 특히 오차(Loss)에 근거하여 미분값(=chain rule)을 계산한다. 이건 error backpropagation algorithm 의 일부이기도 하다. 이 후에 backward 과정을 통해서 자동으로 계산된 미분값과 learning rate(수식 $\ref{gdopt}$의 $\eta$)와 함께 병합하여 모델 학습에 사용한다.

이렇게 입력부터 출력까지의 모든 연산 과정을 기록하고, computational graph를 통하여 chain rule에기반한 미분 값을 계산을 능동적으로 가능해진다. 즉, 모든 계산 과정을 코드에 기록할 필요가 없다. 당연한듯하지만, 예전에는 하나하나 미분된 방정식 계산하고, 직접 소스 코드에 한땀한땀 하드 코딩해서 모델을 학습했다. 그래서 미리 정해진 계산만 할 수 있었다. 모델을 만들고 학습 코드를 만들었다. 하지만 computation graph를 활용하여 순서가 뒤집어 진 것이다. 이 computational graph 덕분에 범용적인 학습 라이브러리가 먼저 배포되고, 모델이 나중에 생성하는 순서가 가능해졌다.

Computation Graph의 Backward 과정

Backward 과정은 미분의 chain rule을 구현한 과정이다. 입력(모델 파라미터)의 변화가 출력의 어떠한 영향을 주는지 computation graph의 미분 방정식을 계산할 수 있다. 미분의 의미는 출력이 원하지 않는 방향으로 높아 졌다면(오류가 증가했다면), 관련 입력(파라미터)을 반대 방향으로 감소하면 될 것입니다. "모델 학습"은 이 과정을 반복적으로 방향을 줄이거나 / 높이거나 한다는 것이다. 이 반복하는 과정을 수학적으로 표현하면 다음의 수식 $\ref{gdopt}$ 처럼된다. 수식 $\ref{gdopt}$는 이미 알고있는 gradient descent optimization 기법이다. 아랫 첨자 수열($t,t+1$)로 표현해서 반복과정을 표현을 많이 한다.

$$\begin{align} w_{t+1} = w_{t} - \eta \frac{\partial L}{\partial w}  \nonumber \\ b_{t+1} = b_{t} - \eta \frac{\partial L}{\partial b} \label{gdopt} \end{align}  $$

[수식 $\ref{gdopt}$] Gradient Descent Optimizatoin 과정


미방의 Chain Rule을 품은 Computational Graph기반 Backward

[그림 1] Computational Graph와 Backward의 도식화

핵심 Insight! $L$의 변화를 $w$까지 연결: $L=e^2=(y - f(g(h(......z(w)))))^2$

흔히 언급되는 손실함수(loss function)인 $L=(y-\hat {y})^2$에서 $w$를 학습하기 위해서는 $\frac{\partial L}{\partial w} = \textup{blackbox}$을 반듯이 계산해야 한다. 하지만, 현재까지의 지면상으로는 구체적으로 어떤 모양인기 언급한 적이 없어서 $w$와 $L$은 직접적인 연관을 지을 수 없다. 지금부터 연결 짓는 과정을 설명한다. 입력부터 여러 단계의 계산을 거친 최종 결과는(오류의 크기를 나타내는) 변수 $L$이다. 그럼 두 변수를 어떻게 연관을 지을까? 바로 답은 편미분방정식의 chain rule에 있다. Chain Rule은 각 단계별로 변화량을 누적시킨 과정이라고 볼 수 있다. 

$w$부터 여러단계를 거쳐서 $L$까지 이루어 진 것에 chain rule을 적용하면, $$\frac{\partial L}{\partial w} =  \frac{\partial z}{\partial w} \frac{\partial \hat{y}}{\partial z}   \frac{\partial e}{\partial \hat{y}} \frac{\partial L}{\partial e} \label{df_L}$$ 이렇게 계산된다. 위에 그림 1처럼 붉은색 원부터 보라색 점선을 따라서 초록색 원(모델 파라미터)까지 각 노드의 미분값을 곱하면 Chain Rule과 동일하다(=backward). 파란색 원은 고정된 값(훈련 데이터셋)이므로 미분값을 계산하지 않는다. 이 처럼 computational graph의 기록된 정보를 바탕으로 단순한 곱하기 과정을 누적하면 chain rule이 완성된다. 

이 미분식 $\ref{df_L}$는 입력의 변화량과 출력의 변화량의 관계를 설명하는 공식이다. 다시말하면, $L$값의 변화량을 보고 $w$와 $b$의 이동 방향을 결정할 수 있다. 이 원리를 활용하여 손실이 적은 최적의 모델 파라미터($w$와 $b$)를 찾을 수 있다.


$w$의 Chain Rule속의 $\frac{\partial Y}{\partial X}$꼴 구체화 실습 해보기

그림 1은 아래의 방정식의 computational graph이다:
$$z = xw \nonumber$$
$$\hat {y} = z + b \nonumber$$
$$e = y - \hat{y} \nonumber$$
$$L = e^2 \nonumber$$
그리고 이것을 미분해보겠습니다.
$$\frac{\partial z}{\partial w} =x \nonumber$$
$$\frac{\partial \hat{y}}{\partial z}=1 \nonumber$$
$$\frac{\partial e}{\partial \hat{y}} = -1 \nonumber$$
$$\frac{\partial L}{\partial e} = 2e \nonumber$$
$$\frac{\partial L}{\partial L} = 1 \nonumber$$
자 이제 Chain Rule로 계산된 값을 각각 교체해 본다. 그러면 다음과 같이 계산된다:
$$ \begin{align} \frac{\partial L}{\partial w} &= \frac{\partial z}{\partial w} \frac{\partial \hat{y}}{\partial z}   \frac{\partial e}{\partial \hat{y}} \frac{\partial L}{\partial e} \\ &= (x)(1)(-1)(2e)(1)=-x(y-\hat{y}) \label{error} \end{align}$$

이 수식 $\ref{error}$에서 $e$값은 마지막 예측한 값($\hat{y}$)과 참 값($y$)의 오류로 볼 수가 있다. 학습하려는 모델인 $wx + b$에서 이 오류($e$)가 입력인 $x$를 곱하여 만들어진다(수식 $\ref{gdopt}$의 첫번째 식, $w$ 학습이다). 이 수식의 의미는 강한 입력은 강한 오차를 학습에 반영한다. 

결론

Computational graph에서 오른쪽 맨 끝($L$)에 오류가 왼쪽의 입력단에 $w$, $b$까지 전파된 것을 볼 수 있다. 즉, backward과정을 자세히 살펴보았다.

MathJax와 TikJax 태그 삽입



<!--MathJax 3 start-->
<script>
MathJax = {
  tex: {
	inlineMath: [['$', '$'], ['\\(', '\\)']],
    displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
    processEscapes: false,
  	tags: 'all',
  }
};
</script>
    
<script src="https://polyfill.io/v3/polyfill.min.js?features=es6">
<script async='async' id='MathJax-script' src='https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js'/>
<script async='async' id='MathJax-script' src='https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml.js'/>

<!-- MathJax 3 end -->

<!-- tikzjax start --> 
<link href='https://tikzjax.com/v1/fonts.css' rel='stylesheet' type='text/css'/>
<script src='https://tikzjax.com/v1/tikzjax.js'/>
<!-- tikzjax end -->
<!-- Highlight start -->
<link href='//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.15.6/styles/atom-one-dark-reasonable.min.css' rel='stylesheet'/>

<script src='//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.15.6/highlight.min.js'/>
<script src='//cdnjs.cloudflare.com/ajax/libs/highlight.js/9.15.6/languages/r.min.js'/>

<script>hljs.initHighlightingOnLoad();</script>
<!-- Highlight end --> 
https://tohtml.com/jScript/

Gaussian Mixture Model, GMM

왜 Gaussian Mixture Model (GMM)을 기법을 사용할까?

관측한 변수의 분포가 복잡하여 설명이 힘들때 혼합 모델(mixture model)을 사용하여 더 쉽게 설명하기 위해서이다. 이때 보통 잠재 변수(latent variable)를 도입하고, 보통 잠재변수의 분포 모델은 수학적/계산적으로 용이한 모델을 사용한다. 즉, 단순한 모델의 조합으로 복잡한 모델을 설명한다. 결국, 잠재변수의 형태를 사전에 정해야 한다. 예를들면 잠재변수는 gaussian 분포를 따른다고 가정한다. 이러면 Gaussian Mixture Model이 된다.

최종적으로 단순한 모델의 인자들이 추정이 끝이나면, 우리는 복잡한 모델을 잘 설명하는 인자를 추정하게 된 것이다.

잠재 변수($z$)로 복잡한 모델을 설명하기

아래의 그림에서 $p(X)$ 분포를 어떻게 모델링해야 하지? 

그림 1 $p(x)$
[그림 1] $p(x)$의 분포

$p(x)$를 보면 특징이 있다. 바로, 봉오리(peak)가 3개가 있다는 것이다($\mu$). 그리고 좌우로 퍼진 정도는 다르다($\sigma$). 그럼 3개의 가우시안 모델을 가지고 설명을 할 수 있다고 생각할 것이다. 그러면 잠재변수($z$)를 3개 놓고, $z$는 가우시안 분포를 따를 것이라 가정하자. 그러면 아래 그림2와 같이 잠재변수의 분포를 가우시안(gaussian) 모델로 그림 1을 설명할 것이다. 물론 오차는 존재한다.




[그림 2] 3개의 잠재변수 $z$로 설명한 $p(x)$ 분포.   여기서, $p(x|z_k)=g(x|\mu_k,\sigma_k)$의 의미는 $z_k$가 주어졌을때의 $x$의 분포이다.

Marginal $x$을 한다는 것은? 잠재변수 $z$의 소거. 그것의 의미는?

단순히 잠재변수를 제거하여 $z$를 고려하지 않겠다는 표현이다. 즉, $z$를 모두 더하여 소거한다. 직관적인 의미를 알아보면 3개의 가우시안 모델이 주어졌을때, 오직 $x$의 관점에서 분포를 알기 위하여 $z$를 선형으로 중첩한 것이다. 결론적으로 잠재변수의 분포 가정과 marginal 연산을 사용하여, 가우시안 모델로 복잡한 $p(x)$를 설명 할 수 있게 된다. 마치 그림 3과같이 3개를 그래프(녹색, 붉은색 그리고 하늘색)를 중첩(표시)해서 핑크색 $p(x)$를 만드는 것과 비슷하다. 이것의 수학적 의미로 잠재변수를 소거를 위한 "marginal한다"는 것과 비슷하다. 물론, 관측된(파란색) $p(x)$와 중첩된(핑크색) $p(x)$사에이는 오차가 있다.



수학의 언어로 모델링 과정을 살펴보자.

잠재변수와 $x$를 수학적 심볼로 해석을 해보자. 우리의 직관을 매우 정교하게 설명한 것뿐이다. 잠재변수($z$)가 있다고 가정하면, 두 변수를 결합한 확률은 단순하게 $$p(x,z) \label{p_x_z}$$ 결합확률 분포로 표현된다. 이것만으로는 $z$로부터 $x$를 표현할 수 없다. 여기서 중첩원리를 이용하자. 바로 marginal이다. 

그러면, $x$의 marginal은 
$$p(x) = \sum_{k=1}^{K} p(x,z_k)$$
가 될 것이다. 하지만, 아직도 $p(x,z_k)$를 우리가 아는 가우시안 모델로 연결을 시키기에는 역부족이다. Factorization기법[참조] 이용해보자. 여기서 $p(x,z_k)$는 $z$로부터 $x$가 생성되는 것을 가정한다면 factorization 원리의 따라서, $p(x,z_k) = p(x|z_k)p(z_k)$가 된다. 
그리므로  $$p(x) = \sum_{k=1}^{K} p(z_k)p(x|z_k) \label{factorization}$$

방정식 $\ref{factorization}$에서 무언가 구체화 됬다($z_k$로부터 $x$를 표현할 수 있다). $p(x|z_k)$는 그림 2에서 한 개의  가우시안 모델로 구체화를 할 수 있다(likelihood이다). Likelihood는 모델의 인자가 정해질 경우 $x$의 확률을 계산할 수 있다. 즉, 그림 2에서의 각 색별로 가우시안 분포를 계산할 수 있다. 수학적으로 자세히 써보면, 가우시안에 근거한 likelihood $p(x|z_k)$는 $N(x|\mu_k,\sigma_k)$가 된다. 잠재변수 $z_k$의 가우시안 모델이라고 할 수 있다.
$p(z_k)$는 $z_k$의 가중치가 된다. 더군다나 $x$와는 상관없는 확률이다. 단순한 더하기의 중첩이 아니라, 가운시안 모델, $p(x|z_k)$의 가중치의 합이다. 

최종적으로 
$$\begin{align}p(x) &=& \sum_{k=1}^{K} p(z_k)N(x|\mu_k,\sigma_k) \nonumber \\ &=& \sum_{k=1}^{K}w(z_k)*N(x|\mu_k,\sigma_k) \label{nzk} \end{align}$$

$p(x,z)$[$\ref{p_x_z}$]의 분포로부터 잠재변수 가우시안 분포[$\ref{nzk}$]까지 구체화를 하였다.

K-Clustering과 GMM의 관계는?

만약 $z_k$가 one-hot encoding일때를 칭한다. One-hot encoding은 1개만 1일고 나머지는 0인 배열을 가지는 것이다. 즉, $z_1$ = [1,0,0], $z_2$ = [0,1,0], $z_3$ = [0,0,1]이 된다. 그리고 K는 여기서 3이다.

그림 2에서 처럼  $x$위에 모든 점들은 likelihood( $p(x|\mu_k,\sigma_k)$ )가 가장 큰 값에 해당하는 색깔을 계산적으로 찾을 수 있다($k=\underset{k}{\operatorname{argmax}} \ {p(x|z_k)}$). 이는 결국 빨강, 녹색과 하늘색 중 하나에 귀속을 판정할 수 있다.

EM을 사용하여 GMM의 인자를 찾는다.

이 모델을 $z_k$를 찾는 방법은 보통 MLE(Maximum Likelihood Estimation)의 해법으로 EM(Expectation Maximization) algorithm을 많이 사용한다.
EM은 보통 2단계로 구성된다:
  1. E단계: $(\mu,\sigma)$를 고정하고, 모든 데이터의 $p(z_k|\mu,\sigma)$의 값을 계산한다.
  2. M단계: 가장 잘 설명하는 모델의 파라미터($\mu,\sigma$)를 새롭게 갱신한다.
  3. 1,2단계 반복, 단, 모델 파라미터의 변화가 작으면 종료한다.
자세한 EM은 다른 글에서 설명하기로 하자.



Importance Sampling이란?

복잡한 함수 $p(z)$를 가지는 $f(x)$의 기대값을 계산하고 싶지만, 수학적으로 계산하기 어려울 경우 사용하는 방법 중 하나이다. 특히 $p(z)$로부터 샘플을 추출하기가 어려울 경우 사용한다. 바로 $p(z)$ 대신 단순한 $q(z)$의 샘플을 사용하여 $f(z)$의 기대값을 계산한다. 비록 $p(z)$는 어려운 형태이지만,  모든 $z$값이 주어지면 $p(z)$의 값을 쉽게 계산할 수 있어야 한다. 더 자세히 말하면, $p(z)$보다는 $\tilde{p}(z) / Z_p$에서 $\tilde{p}(z)$이다. 이유는 $Z_p$(정규항)를 계산하기 어렵기 때문이다. 정규항이란 값을 정규범위 값으로 만들어주는 항이다. 모르기 때문에 그냥 퉁쳐서 역할상의 의미로 정규항이라고 칭한다. 여기서는 확률 분포이므로, 0~1까지의 값으로 만들어 준다.


Importance sampling 공식 살펴보기

$$E_{z \sim p}[f] \approx  \frac{1}{L} {\sum_{l=1}^{L}{  \color{Blue}\frac{p(z^l)} {q(z^l)}} f(z) } = \frac{1}{L} {\sum_{l=1}^{L}{  \color{Blue}w(z^l) }f(z) } \tag{1} \label{ef}$$

여기서 $\frac{p(q^l)}{q(z^l)}$이 샘플 $z^l$의 가중치(weight)라고 볼 수 있다. 즉, 가중치는 두 $p(z)$와 $q(z)$의 비율이라고 해석할 수 있다. 이 $q(z)$와 $p(z)$는 샘플로부터 근사적으로 계산할 수 있는 값이다. 즉, $w(z)$는 바로 계산이 가능하다. 단, $\sum_{l=1}^{L} w(z^{l}) = 1$이 되도록 정규화를 하면된다. 

Importance sampling 기법은 특별히 샘플 $z$를 버리거나 하지 않는다. 무조건 $f(z)$의 기대값[$\ref{ef}$]을 계산할 때 사용한다.

좋은 $q(z)$는 무엇일까?

  1. $p(z)$와 가능한한 모습이 비슷하면 좋다.
    1. 샘플링 낭비가 발생한다.
      1. Uniform sampling으로 하면 샘플$z^{l}$을 낭비할 수 있다.
      2. 모양이 다르면 $w(z^{l})$의 값이 0에 가까울 수 있다.
  2. 샘플링이 쉬운 확률 분포이면 좋다.
  3. 결과적으로 사용자의 안목이 중요하다.

[Rust] Ownership, Scope, Transfer Ownership 과 Borrowing

Scope, Ownership, Transfer Rust에서 사용하는 영역, 소유, 소유권 이전, 복사, 빌림을 요약해 보는 것이 목적이다. 소유(Ownership)은 Java, Go와 같은 백그라운드에서 실행하는 garbage collector가 없...