일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | ||||
4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | 27 | 28 | 29 | 30 | 31 |
Tags
- Tree backup
- n-step
- Maximazation bias
- 온폴리시
- Actor-Critic
- MAML
- 파이썬 인터프리터 락
- 오프폴리시
- 병행성 제어
- Maximum entropy
- Control variate
- Concurrency Control
- Double learning
- docker tensorboard
- 인터프리터 락
- Global Interpreter Lock
- Off-policy
- Python Interpreter Lock
- Importance sampling
- Reinforcement Learning
- 강화학습
- 도커 텐서보드 연결
- Interpreter Lock
- Few-shot learning
- Meta Learning
- 통합 개발
- 지속적 개발
- Soft Actor-Critic
- 전역 인터프리터 락
- 중요도 샘플링
Archives
- Today
- Total
HakuCode na matata
[MAML] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks 본문
Machine Learning/Meta Learning
[MAML] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
@tai_haku 2022. 6. 17. 17:11반응형
Abstract
- ‘메타러닝’의 목표 = 개별 과제들에 대해 소량의 데이터를 기반으로 효과적인 학습을 거듭하여 다수의 과제들에 대한 최적의 일반화 성능을 가진 파라미터 학습
- 메타러닝 알고리즘은 어떠한 모델에 대해서도 적용 가능한 메커니즘
- Supervised-learning과 Reinforcement-learning에서 실험한 결과, 기존 SOTA와 유사하거나 상회하는 성능을 보임
Introduction
- 소량의 데이터를 기반(Few-shot)으로 다수의 과제(Multiple tasks)를 빠르게 학습(fast adaptation & fine-tune)하는 알고리즘을 구성하고자 함
- 상기 학습 메커니즘 구현 간 어려움
- 데이터 부족(Lack of Data)
- (High bias) Overfitting 가능성 높음
- 다수의 과제(Multiple tasks)
- (Data type) Data type, size, domain 등의 과제 별 특성에 따른 데이터 형식에 차이 존재
- (Combination) 이를 이전 학습 결과와 통합해야 함
- 데이터 부족(Lack of Data)
- 따라서, 본 논문에서 경사하강법(Gradient descent)을 기반으로한 ‘메타학습 알고리즘’ 제시
- 신경망 파라미터를 기반으로 하는 학습방식에 따라 신경망 이외의 추가적인 파라미터 불필요
- 이러한 학습 메커니즘은 관점에 따라 아래와 같은 의미를 가짐
- Feature learning 관점 → General feature를 학습하는 과정
- Dynamical system 관점 → Sensitivity of Loss function를 최대화하는 과정
Model-Agnostic Meta-Learning
- 본 Section에서는 Few-shot learning 문제상황에서 Rapid adaptation을 학습을 목표로 함
Meta-Learning Problem Set-Up
- 구성요소
- $f$ → 모델(input: $x$, output $a$)
- $T=\{\mathcal{L}(x_1,a_1,...,x_H,a_H),\ q(x_1),\ q(x_{t+1}|x_t,a_t),H\}$ → 과제
- $\mathcal{L}$ → 손실함수(Loss function)
- $q(x_1)$ → 초기관측분포
- $q(x_{t+1}|x_t,a_t)$ → 상태천이분포
- $H$ → 에피소드 길이
- 학습과정에 대한 다이어그램(Fig 1.)
- 메타러닝 시나리오 상에서 본 논문은 전체 과제 $T$가 가지는 분포 $p(T)$를 고려함
- $p(T)$상에서 특정 과제 $T_i$를 추출
- $T_i$가 가진 $K$개의 샘플 추출
- $K$개의 샘플에 대한 손실함수(오차) $\mathcal{L}$ 계산
- 모델 $f$는 $q_i$로부터 나온 새 데이터에 대한 테스트 오차(test error)가 파라미터의 이동을 얼마만큼 만들어 내는지를 고려하며 학습함
- 이러한 테스트 오차(test error)는 사실상, 메타러닝의 훈련 오차(trainning error)로서의 역할을 함
- 마지막으로 이러한 메타러닝의 끝에서 새로운 과제들이 $p(T)$로부터 추출되고 $K$개의 추출된 샘플을 통해 학습된 모델에 의해 메타 성능(meta performance)을 측정함
A Model-Agnostic Meta-Learning Algorithm
- 알고리즘 의사코드
- 기존의 방법들은 RNN이나 Non-parametric기반의 Feature embedding 기법을 사용해 Fast-adaptation 구현
- 본 논문에서는 일반화 표현을 창출하기 위해 경사하강법(Gradient Descent) 기반의 학습 알고리즘 제시
- 구성 요소
- $\theta$ → 학습 파라미터
- $f_\theta$ → (학습 파라미터 기반) 모델
- $\alpha$ → 학습률
- $\theta_i'$ → $i$번째 과제에 대해 추출된 샘플을 기반으로 학습한 파라미터
- $J(\theta)$ → 목적함수
- 주의할 점은 메타 최적화(Meta optimization)는 $\theta$에 대해 수행되지만, 목적함수(Objective function) 값은 개선된 파라미터인 $\theta'$를 기반으로 계산된다는 것
- 메타 최적화는 확률적 경사하강법(Stochastic Gradient Descent)를 기반으로 수행됨
$$ \theta \leftarrow \theta - \beta\nabla_\theta\underset{T_i\sim p(T)}{\sum}\mathcal{L}{T_i}(f{\theta_i'}) $$
- 여기서의 $\beta$는 메타 최적화에 사용되는 학습률
- 이러한 메타러닝 방식은 경사를 통해 얻은 경사(Gradient through a gradient)를 포함하기 때문에, Hessian-vector계산을 위한 추가적인 역전파 과정이 필요함
- 본 논문에서는 이러한 계산과정의 유무에 따른 차이도 비교함
Species of MAML
- 본 Section에서는 Supervised-learning과 Reinforcement-learning에 적용한 예시에 대해 알아봄
- 이는 각각 학습법에 차이에 따라 3데이터 생성방법이나 표현방법에 차이가 존재하나, 적용 메커니즘 자체는 동일
Supervised Regression and Classification
- 본 논문에서는 회귀(Regression) 문제를 풀기 위해 손실함수로서 MSE(Mean-Squared Error)를 사용
$$ \mathcal{L}{T_i}(f\phi)=\underset{x^{(j)},y^{(j)}\sim T_i}{\sum}||f_{\phi}(x^{(j)}),y^{(j)}||^2_2 $$
- $(x^{(j)},y^{(j)})$는 각각 (입력 / 출력) 으로서 j번째 입출력 쌍을 의미함
- 이산 분류(Discrete Classification) 문제에 대해서는 손실함수로서 Cross-Entropy를 사용
$$ \mathcal{L_{T_i}}(f_\phi)=\underset{x^{(j)},y^{(j)\sim {T_i}}}{\sum}{y^{(j)}}logf_\phi(x^{(j)})+(1-y^{(j)})log(1-f_{\phi}(x^{(j)})) $$
- 상기한 손실함수를 기반으로 아래와 같은 의사코드를 통해 메타러닝 구현
- 해당 알고리즘을 통해 전형적인 N-way K-shot classification 학습 구현
Reinforcement learning
- RL에서의 메타러닝 목표 = 새 정책(Policy)에 대한 Fast-adaptation
- 과제 별 목표는 달라질 수 있지만, 메타를 빠르게 학습하기 위한 취지
- 구성 요소
- $f_\theta$ → (학습 파라미터 기반) 모델 → 상태 $x_t$를 행동분포 $a_t$로 사상(mapping)
- $\mathcal{L}{T_i}(f\phi)$ → 손실함수
- 상기한 손실함수를 기반으로 아래와 같은 의사코드를 통해 메타러닝 구현
- 해당 알고리즘을 통해 강화학습 에이전트 구현
- 학습 간 K개의 데이터 샘플을 사용하기 위해 다음과 같은 요소들을 사용함
- 모델 $f_\theta$
- 과제 $T_i=\{x_1,a_1,...,x_H\}$
- 보상 $R_i=R_i(x_t,a_t)$
- 환경의 Dynamics가 알려지지 않은 상태이기 때문에, 기대보상($G$)에 대한 미분이 불가
- 따라서, 모델 경사 및 메타러닝 경사를 계산하기 위해 정책경사(Policy Graidient) 방법을 사용
- 또한, 이러한 PG방식의 알고리즘은 On-Policy 알고리즘이기 때문에, 모델 적용 간 추가적인 경사하강을 위해 현재 정책($f_{\theta_{i}'}$)으로부터 추출된 새로운 샘플이 필요
- 이 과정에서 RL과제에 대한 샘플의 길이는 종전의 지도학습과 다르게 5~8정도의 샘플(Transition) 길이를 갖음
Related Work
- 이와 관련한 기존 연구로는 ‘가중된 초기화 파라미터 사용’이나 ‘학습 네트워크의 최적화’에 중점을 둠
- 이와 다르게 본 연구는 추가적인 파라미터 없이 경사하강법 기반의 알고리즘임
- 또한, 기존의 few-shot 학습은 생성모델링(Generative modeling)이나 이미지 인식(Image recognition)에 특정하여 발전되어왔음
- 허나, 여기서 사용하는 방식은 RL로의 확장이 어려움
- 또다른 학습 방식으로는 ‘다중 과제를 위한 확장된 메모리 기반 학습’이 있음
- 하지만, 본 알고리즘이 성능이 더 우수
- 추가적으로 우수한 초깃값 설정, 학습자와 메타학습 간 사용 메커니즘이 모두 경사하강법으로 동일하기에 직관성 우수
- 본 알고리즘은 신경망 초기화 메커니즘과도 유관
- Vision 분야에서 빅데이터 기반 사전훈련모델을 통해 우수한 초깃값 설정의 중요도는 높음
- 본 알고리즘은 새 과제에 대해 민감도가 높은(적응력이 우수한) 초깃값으로 설정 가능
Experimental Evaluation
- 본 논문의 실험평가요소 3가지
- MAML은 빠른 학습이 가능한가?
- Regression 과제를 통해 MAML이 빠른 적응(Adaptation)이 가능한 위치로 매개변수를 최적화하며, 이는 개별 과제가 아닌 전체 과제에 대한 분포 $p(T)$의 손실함수에 민감하다는 것을 나타냄
- Classfication 과제를 통해 2계도 미분 과정 없이 메타 경사기반 효과적인 학습이 가능함을 확인함으로써 Hessian-vector에 대한 연산 과정이 생략되어 33% 정도의 연산시간 축소
- MAML이 Supervised-Learning이나 Reinforcement-Learning 등 다양한 도메인에 사용될 수 있나?
- Classification 과제에 대해 SOTA와 필적하거나 일부 구간에서 상회하는 성능을 보임 → 또한, 기존 SOTA는 다양한 도메인에 적용이 어려우나, MAML은 그렇지 않음
- MAML로 학습된 모델이 추가적인 경사나 샘플기반 갱신이 가능한가?
- RL-2D Navigation 과제에서 앞선 1, 2번 요소들과 더불어 추가적인 경사에 의해 지속적인 개선이 가능함을 확인
- MAML은 빠른 학습이 가능한가?
Discussion and Future Work
- 본 연구에서 사전학습모델없이 적용이 가능한 메타러닝 기법 소개
- 이 방법은 어떠한 represention에 대해서도 적용이 가능하며 어떠한 도메인의 과제도 적용 가능
- 또한, 특정 가중된 초기화 방법 등을 사용하지 않아 어떠한 데이터량이나 경사량에 대해서도 적용가능
- 이러한 연구는 메타러닝을 위한 초석이 될 것임
반응형
Comments