[Fundamentals]/[Fundamentals] 논문

[201804]Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results

do.hyeon 2025. 1. 6. 16:16

Author: Antti Tarvainen and Harri Valpola | Arxiv, 2018

https://arxiv.org/abs/1703.01780

 

요약

 이 논문에서는 unlabeled data를 labeled data와 함께 학습에 활용하기 위한 Semi-supervised model을 소개한다. 제안하는 모델은 Mean Teacher로, 기존에 SSL에 사용되던 gamma-model이나 Temporal ensembling보다 더 우수한 성능을 낸다.  또 이 방법 덕분에 각 에포크 이후 업데이트되던 Temporal Ensembling과 달리, 매 학습 스텝마다 모델의 평균 가중치를 업데이트할 수 있게 되어 더 빠르게 학습할 수 있게 되었다.

 teacher-student 구조인 이 모델은 student가 먼저 labeled data에 대해 예측을 하고, teacher model도 같은 input에 대해 예측을 수행한다. student model은 (1)실제 값과 예측값의 오차를 classification cost(Cross-Entropy)를 통해 학습하고, (2)classification cost 대신 student model의 EMA에 의해 업데이트된 teacher model과의 예측 오차인 consistency cost(MSE)를 추가 제약으로 적용한다.

 

 

기본 개념 정리

Gamma model

Ladder Networks를 단순화한 모델이다.
(1) Ladder Networks는 인코더와 디코더로 이루어져 있는 SSL 모델로, 인코더는 학습 과정에서 데이터에 노이즈를 주입하고 디코더는 노이지한 출력 y~로부터 깨끗한 출력 y를 복원하는 방향으로 denoising function g를 학습한다. 이를 반복하며 인코더에 Cost function C_d를 계속 업데이트하고, 학습이 완료되면 인코더는 디코더 없이도 깨끗한 출력 y, 즉 더 robust한 예측을 출력할 수 있도록 조정된다.
인코더는 x - f(1) - f(2) y~ 방향으로 작동하고, 디코더는 x -> z~(1) - g(1) - g(0) 방향으로 작동한다.
(2) Ladder Networks는 여러 계층에 lateral connection이 존재하고 있다. 반면 Gamma model은 이를 단순화하여, 가장 높은 층에만 lateral connection이 존재한다. 덕분에 더 높은 계산 효율성을 가진다.

본 논문에서는 Encoder가 student model로서 teacher model인 Decoder와의 출력 일관성을 학습한다는 방식으로 이 개념에 접근한다.
Pie model & Temporal Ensembling

이 두 모델 전부 Self-assembling을 활용한 Pseudo Ensemble Agreement 기법이다.

Temporal Ensembling을 Pie model의 개선된 버전이라고 본다.
Pie model은 동일한 데이터에 대해 두 번의 augmentation과 dropout이 적용되고, (1)첫 번째의 데이터에 대하여 기존에 알고 있던 분류 지식으로 예측에 따라 cross-entropy를 계산하고, (2) 이 데이터와 두 번째 데이터(다른 증강과 드롭아웃이 적용된) 데이터 간 squared difference를 계산한다. 이를 통해 다른 비슷한 데이터에도 일관성 있는 예측을 할 수 있도록 조정된다.

Temporal Ensembling은 augmentation과 dropout은 한 번만 한다. cross-entropy를 계산하는 것도 Pie model과 같지만, 이전 스텝의 데이터와 squared difference를 계산한다는 점이 차이점이다. 또 현재 스텝의 데이터를 저장해 두고, 다음 스텝에서 squared difference 계산에 활용하는 식으로 작동한다. 이때 저장되는 Z는 EMA를 통해 업데이트된다.

Temporal Ensembling은 한 번만 augmentation & dropout을 적용하므로 계산 효율성이 더 높고(당연히 메모리는 더 많이 듦), 이전 에포크 정보를 누적하므로 학습이 더욱 안정적이다.
EMA(Exponential Moving Average)

데이터의 부드러운 추세를 파악하기 위해 Moving Average를 사용한다.
Moving Average에는 크게 세 가지가 있다.
(1) Simple Moving Average: 특정 시점 데이터 n1, n2, n3이 있다고 하면 (n1+n2+n3)/3으로 계산한다.
(2) Weighted Moving Average: 가중치 w1, w2, w3이 있다고 하면 (w1n1+w2n2+w3n3)/(w1+w2+w3)으로 계산한다.
(3) Exponential Moving Average
- Z_t = ɑZ_{t-1} + (1-ɑ)z_t로 계산한다.
- Z는 EMA값, z는 특정 시점에서의 실제값, ɑ는 Momentum or Smoothing factor이다.(작을 수록 최신 데이터 비중 높음)
- 가중변수를 이용하여 과거 데이터의 영향력을 낮추고 현재 데이터의 영향력을 높이는 등 조정이 가능하다는 장점이 있다.

 

 

Intro: SSL의 필요성과 기존 기법의 한계

 딥러닝 기술이 빠르게 발전하고 있지만, 그렇다고 해서 실생활 모든 분야에 딥러닝이나 인공지능 기법을 적용할 수는 없다. 가장 큰 이유는 모든 분야에 충분한 데이터가 존재하는 게 아니기 때문이다. 내가 최근 공부를 시작한 Sound Event Detection도 예외는 아니다. 이 분야에는 Strongly-labeled dataset, Weakly-labeled dataset, Unlabeled dataset 이렇게 세 종류의 데이터가 존재한다. 이들은 각각

- Strongly-labeled dataset: 어느 sound event가 어느 시간 구간에서 발생했는지 기록된 데이터셋

- Weakly-labeled dataset: 어느 sound event가 발생했는지는 기록되었지만, 어느 시간 구간인지는 기록되지 않은 데이터셋

- Unlabeled dataset: label(어느 sound event인가)이 없는 데이터셋

을 뜻하고, Strong-labeled dataset은 구축하는 데에 많은 시간이 들고 비효율적이어서 그 수가 부족하다.

 

 그렇기 때문에 Unlabeled dataset을 어떻게 하면 지도학습에 활용할 수 있을지에 대한 고민이 필요했고, 그래서 등장한 개념이 Semi-supervised learning이다. 그리고 이 과정에서 오버피팅을 막고 unlabeled data를 효율적으로 사용하기 위해 일종의 규제 기법이 필요하다.

 

 조금 더 구체적으로는, 사람은 특정 지각(percept)이 약간 변하더라도 여전히 그것을 동일한 객체로 간주하는 경우가 많다. 이와 마찬가지로, 분류 모델도 유사한 데이터 포인트에 대해 일관된 출력을 제공하는 규제 함수를 선호해야 한다. 이를 위해 데이터에 노이즈를 추가하는 기법이 논의되었고, 드롭아웃 기법도 논의되었다. 그러나 unlabeled examples에서는 classification cost가 정의되어 있지 않기 때문에 상기 기법은 적용하기 어렵다. 이를 극복하기 위해 Gamma model을 사용하기 시작했는데, 이 모델은 위에서 설명한 구조와 같이 스스로 타겟을 만들고 자기 이전 상태를 기반으로 학습하기 때문에 부정확할 수 있다는 문제가 있었다.

 

 타겟 퀄리티를 향상시키기 위한 약간 다른 접근 방식인 Pseudo-Ensemble Aggrement도 존재한다. 이 접근 방식을 띄는 대표적인 모델로 π-model이 있고, 더 개선된 버전인 Temporal Ensembling도 있다. 그런데 Temporal Ensembling 역시 한계가 있었고, 이 모델의 타겟은 한 에포크에 한 번 업데이트가 이루어지기 때문에 너무 느리다는 점을 본 논문에서 지적하고 있다.

 

 

Mean Teacher

 

대안으로 Mean Teacher를 제시하였다.

 

우선 작동 방식은 다음과 같다.

- 먼저 student model과 teacher model에는 각각 랜덤 노이즈 𝜼, 𝜼'이 적용되어 있다.

- 하나의 데이터 미니배치에 대하여 student model이 예측을 수행한다.

- 예측 결과에 따라 classification cost를 산출한다. (이 논문에서는 Cross-Entropy를 사용했다)

- teacher model은 자신의 파라미터를 student model의 파라미터의 EMA값을 기반으로 업데이트한다.

- teacher model 또한 student model이 예측한 동일한 데이터 샘플에 대하여 예측을 수행한다.

- student model과 teacher model 간 예측 결과를 비교하는 consistent loss를 계산한다. (이 논문에서는 MSE를 사용했다)

- student model은 classification cost와 consistent loss의 합Total loss로 업데이트한다.

- 모든 학습 과정이 종료될 때까지 처음부터 다시 반복한다.

 

 

<1> Consistent loss

 여기서 consistent loss를 사용하는 이유를 간단하게 소개하자면, student model에 비해 teacher model은 EMA를 기반으로 조금 더 부드럽게 업데이트되기 때문에 더 안정적인 모델이라고 본다. 그래서 student model이 이 teacher model을 따라가다 보면 예측 성능이 더 좋아질 수 있다고 본 논문은 주장하고 있다.

 

 또 consistent loss function으로 MSE, Gradient Descent 두 가지 방법을 고려한다. 그러나 MSE가 일반적으로 더 성능이 좋았다고 이야기하고 있고, 이 부분은 Appendix 더 읽어 보고 다시 써야지

 

 이제 Consistent loss를 수식으로 표현하면 다음과 같다.

자세히 들어가면,

- E_{x,η,η​}: x,η,η​에 대하여 내부 식의 기댓값(expectation)을 계산 (η,η​은 노이즈라고 위에서 언급되었음, x는 인풋)

- f(x,θ,η) : Teacher model의 출력. θ′은 EMA 기반 가중치

- f(x,θ,η): Student model의 출력. θ은 Total loss 기반 가중치

즉, Teacher model의 출력과 Student model 간 유클리드 거리 제곱 합에 대한 평균을 계산한다.조금 더 고민해서 파 보면 결국 MSE를 나타내고 있다.

 

 Pie-model은 θ와 θ'를 따로 구분하지 않고, Temporal Ensembling은 f(x,θ,η)을 연속적인 예측에 대한 가중평균이라고 근사하는 반면, Mean Teacher는 training step t일때  θ'_t 를 연속적인 θ의 EMA 가중치로 정의한다.

식으로 쓰면 이렇게, 𝜶는 smoothing hyperparameter.

 

 

실험 및 결과

 

두 데이터 모두 클래스는 10가지이다.

 

label이 얼마 되지 않을 때(즉 unlabeled dataset이 많을 때) 다른 모델에 비해 성능이 압도적으로 좋은 것을 확인할 수 있다.

 

다만 label 개수가 많아질 수록 때에 따라 Virtual Adversarial Training 기법이 더 좋아지는 경우도 있었고, 무엇보다 label이 많아질 수록 다른 모델과 성능 차이가 점점 좁혀지고 있다. unlabeled dataset 수가 얼마 되지 않는다면 그렇게까지 차별점을 느끼진 못 하지만, 많은 unlabeled data를 사용해야 할 때 이 방법이 효과적임을 잘 입증하고 있는 지표라고 볼 수 있겠다.

 

 

이 그림은 training step에 따른 classification error rate를 보여주고 있다.

Pie-model과 Mean teacher의 student model만 비교하더라도 확실히 student model이 더 낮은 에러율을 보이고 있다.

또 위에서 teacher model이 EMA 기반으로 업데이트되어 student model보다 안정적이라고 이야기했는데, 실제로 아래쪽 그래프를 보면 student model보다 teacher model이 에러율이 더 낮은 것을 볼 수 있다.

 

 

 

하이퍼파라미터를 이렇게저렇게 변경했을 때의 에러율 그래프이다.

 

중요한 것 몇 가지 소개하면

- 다른 기법에서 쓰이던 input noise와 dropout을 mean teacher에도 적용해야 하냐? 라는 질문에 그렇다 라고 대답할 수 있겠다. (a)를 보면 두 가지를 같이 썼을 때 에러율이 더 낮게 나타나고 있다. (augmentation도 효과적이었음을 보여주고 있다)

- EMA decay와 consistency cost weight는 일정 수준을 넘어서면 에러율이 갑자기 올라간다.

- 저자들은 MSE 대신 KL-divergence도 쓸 수 있지 않을까 고민했는데, 결과적으로 MSE가 조금 더 성능이 좋았던 것 같다.

 

 

 

Mean Teacher model을 사용하지 않는 다른 SOTA 모델들과 비교하였다.

ConvNet 기반 Mean Teacher는 딱히 성능이 더 좋진 않았고, ResNet 기반이 확실히 기존 모델보다 성능이 좋았다.

하지만 이것도 역시 labeled dataset만 사용했을 때보다는 낮은 성능이다(당연하지만).

 

 

결론

지금까지 내용으로 볼 때 Mean Teacher 접근법은 성공적이라고 할 수 있겠다.

성능은 물론, Temporal Ensembling과 비교할 때 1)이전 에포크의 모든 예측값을 저장해 둘 필요가 없고, 2)가중치는 타임스텝마다 그때그때 업데이트 되므로 대규모 데이터셋을 다룰 때나 온라인 학습에도 훨씬 효율적이다.

 

끝으로 저자는 Virtual Adversarial Training과 이 기법을 결합할 때 더 좋은 결과를 산출할 수 있을 것으로 기대하고 있다. 나중에 VAT도 한 번 알아보고 같이 사용할 수 있는 방안도 고민해 보면 좋을 것 같다.