분류 문제를 풀다보면 반드시 만나는 손실함수가 있다. 그것은 바로 cross entropy loss. cross entropy loss가 무엇을 의미하는지, 왜 수식이 그렇게 정의되는지에 대한 탐구는 아래 포스팅에서 진행했으니 참고하길 바란다.
https://steady-programming.tistory.com/89
[ML / DL] Cross entropy loss function in classification problem
지난 포스팅에서 회귀 문제와 분류 문제에서의 loss function은 분포를 가정하고 세운 maximum likelihood function의 간단한 변형임을 살펴보았다.https://steady-programming.tistory.com/36 [ML / DL] Cost function과 Maximu
steady-programming.tistory.com
또한 cross entropy loss, mean squared error loss의 backpropagation은 어떻게 유도되는지 아래 포스팅에서 살펴보았다.
https://steady-programming.tistory.com/90
[ML / DL] Backpropagation of loss function
지도학습의 머신러닝은 회귀 문제 / 분류 문제 중 하나에 속한다. 회귀 문제라면 최종 레이어에서 계산한 예측값 ($\hat{y}_i$)과 실제값 ($y_i$)와의 mean squared error $( L = \dfrac{1}{2} \sum (y_i - \hat{y}_i)^2
steady-programming.tistory.com
본 포스팅에서는 cross entropy loss을 pytorch에서 어떻게 사용될 수 있고 input은 무엇인지, gradient는 실제로 유도한 식대로 계산되는지 살펴볼 예정이다. pytorch에서는 cross entropy loss을 정의하는 방식이 크게 1) nn.CrossEntropyLoss와 2) nn.NLLLoss가 있다. 각 함수의 input은 무엇인지 gradient는 어떻게 계산되는지 살펴보자.
nn.CrossEntropyLoss
pytorch 공식 문서를 살펴보자.
공식문서에 input logits과 targets간의 cross entropy을 계산한다고 적혀있다. logits은 logistic function (또는 softmax)을 태우기 이전 함수값이다.
\[ a_i = \dfrac{exp(z_i)}{\sum^C_{k=1} exp(z_k)} \]
softmax function을 예로 들면 위 함수에서 input으로 들어가는 함수값 $z_k$가 logits이 된다. 즉, nn.CrossEntropyLoss 함수의 입력값으로는 softmax에 의해 변환된 확률이 아닌, 그 이전의 logits 값과 target이 들어가는 것이다.
여기서 잠깐, cross entropy loss의 정의를 다시 한번 짚고 넘어가자. $N$개의 데이터가 있고 $C$개의 레이블이 있는 상황이다.
\[ L = - \dfrac{1}{N} \sum^N_{i=1} \sum^C_{j=1} p_{ij} \log q_{ij} \]
여기서 $q_{ij}$는 모델에 의해 나온 예측 확률값이고 logits이 input으로 들어갔다. 잘 생각해보면 nn.CrossEntropyLoss는 input으로 logit과 target을 받지 않는가. 여기서 자연스럽게 유추해볼 수 있는 것은, nn.CrossEntropyLoss이 logits을 input으로 받지만 softmax도 해주고, log도 씌워줌을 생각해볼 수 있고 이는 실제로 맞다. 이점이 nn.NLLLoss와 가장 큰 차이점이다. nn.NLLLoss은 log_softmax가 된 값을 input으로 받는 반면에, nn.CrossEntropyLoss은 logits을 input으로 받고 내부적으로 log_softmax을 해준다. nn.CrossEntropyLoss가 더 많은 일을 하는 것이다.
코드로 간단하게 확인해보자. pytorch 공식 문서에 있는 예시를 가져왔다.
import torch.nn as nn
loss = nn.CrossEntropyLoss()
# Example of target with class indices
loss = nn.CrossEntropyLoss()
input = torch.tensor([[ 0.0717, 0.1584, 0.0227, -0.2727, 0.0198],
[ 0.0687, 0.0615, -0.2146, 0.0322, 0.0522],
[-0.2973, 0.0801, 0.0482, 0.0385, 0.1304]], requires_grad=True)
target = torch.tensor([3, 2, 0])
output = loss(input, target)
output.backward()
class label의 개수가 5개인 logit `input`과 `target`을 생성한다. nn.CrossEntropyLoss는 input으로 logit을 받으므로 바로 input으로 집어넣는다. output loss는 아래와 같다.
output
"""
tensor(1.8797, grad_fn=<NllLossBackward0>)
"""
input, 즉 logit에 대한 gradient을 확인해보자. 수식상으로 $\dfrac{\partial L}{\partial z}$ 이다.
input.grad
"""
tensor([[ 0.0709, 0.0773, 0.0675, -0.2831, 0.0673],
[ 0.0710, 0.0705, -0.2798, 0.0685, 0.0699],
[-0.2843, 0.0715, 0.0692, 0.0685, 0.0751]])
"""
이 gradient는 어떻게 계산되었을까? 이전 포스팅에서 살펴보았듯이, $(\hat{y} - y) / n$ 이므로 직접 손으로 확인해주면 된다.
softmax = nn.Softmax(dim=1)
(softmax(input) - F.one_hot(target, num_classes=5)) / 3
"""
tensor([[ 0.0709, 0.0773, 0.0675, -0.2831, 0.0673],
[ 0.0710, 0.0705, -0.2798, 0.0685, 0.0699],
[-0.2843, 0.0715, 0.0692, 0.0685, 0.0751]], grad_fn=<DivBackward0>)
"""
동일함을 확인할 수 있다.
nn.NLLLoss
pytorch 공식 문서를 확인해보자.
공식 문서에 친절하게 log probability가 input으로 들어가고 이를 구현하기 위해 neural network을 구성할 때, 마지막 layer에 logsoftmax layer을 추가하라고 나와 있다. (귀찮으면 CrossEntropyLoss을 사용하라고 나와있다.)
실제로 동일한지 아래 pytorch 코드로 확인해보자.
log_softmax = nn.LogSoftmax(dim=1)
output_nll = nll(log_softmax(input), target)
print(output, output_nll)
"""
(tensor(1.8797, grad_fn=<NllLossBackward0>),
tensor(1.8797, grad_fn=<NllLossBackward0>))
"""
계산된 loss가 동일함을 확인할 수 있다.
Conclusion
nn.CrossEntropyLoss와 nn.NLLLoss을 살펴보았다. 내부적으로 어떤 차이점이 있는지, gradient는 실제로 수학적으로 유도된 식과 동일한지 살펴보았다. 코드를 살펴보면 사람들은 nn.CrossEntropyLoss을 더 많이 사용하는 것 같다. 아마도 내부적으로 log_softmax까지 해주기 때문이지 않을까? 또한 softmax layer을 통과하는 tensor는 weight을 필요로 하지 않고 upstream gradient을 전달만하면 되기 때문에 중요도가 상대적으로 떨어지고, 그래서 그냥 한꺼번에 nn.CrossEntropyLoss으로 처리하는게 편해서 이걸 더 많이 사용하는 것 같기도 하다. 어찌됐든 pytorch을 사용하다보면 어떤게 더 편한지 알게 되겠지..
'ML&DL > Pytorch' 카테고리의 다른 글
[Pytorch] MSE Loss in regression problem (1) | 2024.06.10 |
---|
댓글