💡 AI/토이 프로젝트

GAN (Generative Adversarial Network)

U-chan Seon 2022. 6. 14. 20:29

GAN (Generative Adversarial Network) 

GAN은 Data를 만들어내는 Generator와 만들어진 Data를 평가하는 Discriminator가 서로 대립(Adversarial)적으로 학습해가며 성능을 점차 개선해 나가자는 개념이다. 여기서 GAN의 목표는 Generator 를 잘 학습시키는 것이다.

 

ML은 input이 데이터, output은 label 인 정형화된 틀이 있다.

그러나 input이 노이즈, output이 data 인 것도 있다. 그게 바로 GAN이다.

 

두 개의 네트워크가 경쟁을 해가면서 학습을 한다? 그 당시에 존재하지 않았던 개념이다. 딥러닝의 아버지라고도 불리우는 Yann Lecun 교수는 GAN을 가리켜 최근 10년 동안 머신러닝 분야에서 가장 혁신적인 아이디어라고 말했다고 한다.

 

GAN은 겨울 이미지 -> 여름 이미지, 여름 이미지 -> 겨울 이미지로 바꿔주는 것도 한다.


GAN의 구조

위조지폐범 GAN

지폐위조범(Generator) 경찰을 최대한 열심히 속이려고 하고 다른 한편에서는 경찰(Discriminator) 이렇게 위조된 지폐를 진짜와 감별하려고(Classify) 노력한다. 어느 순간이 되면 지폐 위조범의 능력이 극에 달해 진짜와 다름없는 가짜 지폐를 만들게 됩니다. 이 때 경찰의 구별 확률은 로 수렴합니다. 이 상태에서 만들어진 가짜 지폐는 진짜 지폐처럼 보일 것이다.

이런 경쟁 속에서 그룹 모두 속이고 구별하는 서로의 능력이 발전하게 되고 결과적으로는 진짜 지폐와 위조 지폐를 구별할 없을 정도(구별할 확률 p=0.5) 이른다.

 

GAN structure

 

GAN이라는 것은 이름 그대로 세 단계로 구분된다.

  • Generative: 가짜 이미지를 생성하는 Generative model G가 있다.
  • Adversarial: G는 진짜와 가짜 이미지를 구별하는 Discriminative model 와 적대적으로 대립하며 각자 학습해나간다.
  • Network: G D는 Neural network이다.

 

 

그런데 latent space (z) 라는 개념이 나오게 된다. 지폐 위조범이 지폐를 만들 때, 종이가 필요할 것이다. 이 종이는 그냥 임의의 종이를 이용할 것이다. 실제 GAN의 구현에서도 종이의 역할을 하는 noise가 필요하다. 즉, Generator가 하는 일(가짜데이터를 만드는 일) 은 noise로부터 진짜 이미지로 맵핑하는 것이라고 볼 수 있다.

Genrative, Discriminative Distribution

 

위 그림에서 (a)-(d) 각 그림은 시간의 순서를 나타냅니다. 각 그림의 하단부는 noise 로부터 진짜 이미지 x로 맵핑하는 것을 나타낸다. 그리고 상단부에서 검정색 점들의 분포는 실제 이미지의 분포를, 초록색 실선Generator가 만들어낸 가짜 이미지의 분포를 나타낸다. 시간이 흐름에 따라 (학습이 이루어짐에 따라) 점차 초록색 실선이 실제 분포에 fitting되는 것을 확인할 수 있다.

 

파란색 점선Discriminator의 예측 결과이다. 처음에는 아무렇게나 응답하다가, 점점 진짜와 가짜를 구별 잘 하더니, 마지막 (Generator가 완벽해진 순간)에는 모든 예측 확률을 2로 내놓는 것을 확인할 수 있다.

 

Latent space in GAN

 

단순한 MNIST부터 얼굴 이미지, CIFAR-10 이미지까지, GAN을 통해 Training sample과 거의 비슷한 Generaled sample을 얻은 것을 확인할 수 있다.

 

여기서 한 가지 더 주목할 것은 하단의 Linear interpolation입니다. latent space(z)에서 interpolation (보간법)을 진행하면 생성된 결과물이 부드럽게 이어지는 것을 확인할 수 있다. 즉, latent space에 따른 mapping이 단순히 1:1 mapping을 기억하고 있는 것이 아니라는 것을 알 수 있다.


GAN의 학습

GAN structure with UNet

noise인 z가 들어오고 미분 가능한 Function G가 네트워크이다. 이 모델에서 Fake 데이터 G(z)를 만들고, 가짜데이터 G(z)를 D에 넣는다. 그럼 D가 진짠지 아닌지 판별하려고 노력한다.

 

- Discriminator(D)output : 확률 / Generator (G) output : data

D(G(z)) = 0.3 : 가짜 데이터인 G(z)를 판별했는데 진짜일 확률은 0.3이다.

 

- Discriminator를 학습시킬 때에는 D(x)1이 되고 D(G(z))0이 되도록 학습시킨다.

진짜를 진짜라고 판별하게끔 D(x) = 1 이 되게끔 학습시킨다.

(진짜 데이터를 진짜로 판별하고, 가짜데이터를 가짜로 판별할 수 있도록)

 

- Generator를 학습시킬 때에는 D(G(z))1이 되도록 학습시킨다.

(가짜 데이터를 discriminator가 구분못하도록 학습, discriminator를 헷갈리게 하도록)

 


Discriminator의 학습

Feed forward process of Discriminator

Fake 데이터를 0, Real 데이터를 1로 설정하고 Discriminator를 학습시킨다.


Generator의 학습

Feed forward process of Generator

Fake 데이터를 1로 설정하고 Discriminator를 학습시킨다. 즉, Discriminator가 Fake 데이터를 진짜로 생각하게끔 학습시킨다.

 

Generator의 Feed forward 과정은 noise를 Fake 데이터로 만드는 것에서 끝나는 것이 아니라, Discriminator에 넣는 것까지를 포함한다. 일반적인 Neural Network는 input을 네트워크에 넣어서 output이 나오는 것까지인데 Generator는 Discriminator에 넣는 것 까지를 말한다.

왜 그럴까?


Generator의 Back propagation 과정

 

Generator의 입장에서는 Value function을 minimize 시켜야 한다.

즉, D(G(z))가 1이 되게끔 학습시켜야 한다.

Generator만을 가지고 그 Generator를 학습시킬 수 없다. 

 

이상한데? G를 학습시켜야 하는데, D를 학습시켜야 한다? 이게 말이 되나?

 

Discriminator의 Back propagation

 

64차원의 Mnist 데이터 500개를 받아서 28x28(=784)개

 

Generator의 Back propagation 과정을 보자.

GAN은 Discriminator를 먼저 학습시키고 Generator를 학습시킨다. Generator의 weight를 업데이트 시키려면 Discriminator 에서 부터 Back propagation 시켜야한다. 결국 Discriminator의 weight를 가져와서 학습시켜야 한다. 여기서 Discriminator의 weight는 업데이트 시키지 않는다. 

 

GAN의 목적은 Discriminator를 학습시키는게 아니라 Generator를 학습시키는 것이다.

Discriminator를 헷갈리게 학습시킨다는 게 Discriminator를 건드리겠다는 것이 아니라, Discriminator의 error를 가지고 전파시켜서 Generator의 weight를 업데이트 시키겠다는 것이다. 결국 Generator 입장에서는 Discriminator가 고정되어 있지만 weight를 업데이트 시키면 Discriminator가 fake 데이터를 진짜로 생각할 것이다라는 개념이다.


GAN 학습 과정

  1. 먼저 노이즈를 생성하고, Generator에 넣어서, G(Z)가 0이 되게끔, Real data가 1이 되게끔 Discriminator를 학습시킨다.
  2. 그 다음 노이즈를 생성하고 G(Z)가 1이 되게끔 Generator를 학습시킨다. 여기서 Discriminator는 건들이지 않는다.
  3. 이것을 계속 반복한다.

 

Value function

GAN은 D와 G에 대하여 다음과 같은 value function을 minmax problem으로 풀게 된다.

  • D의 입장 : D(x)가 1이고 (진짜 데이터를 1로 구분) D(G(z))가 0일 때 (가짜 데이터를 0으로 구분) V는 최대값 
  • G의 입장 : D(G(z))가 1일 때 (가짜 데이터를 1로 속임) V는 최소값 
  • D의 입장에서 maximize 시키고 G의 입장에서 minimize 시킨다.

 

 

 

 

 


GAN의 학습 결과

 

0.5로 수렴하는 것을 볼 수 있다.

 


GAN의 한계

GAN은 초기 모델인만큼, 한계점이 명확하다. 일단 뭔가 비슷한게 생성되는 것 같긴 한데, 아직은 좀 부족한 느낌이 강하다. 실제로 MNIST와 같이 단순한 이미지는 잘 생성하지만, 동물같이 복잡한 이미지는 만들어내기 힘들어 하는 것을 볼 수 있다.

 

하지만 이보다 더 큰 문제는, 학습의 안정성이 떨어진다는 점이다.

 

GAN이 갖는 불안한 안정성의 대표적인 예시로 Mode collapse가 있는데, 이는 학습의 다양성이 떨어지는 것을 말한다. 즉, 지폐 위조범의 예시에서 한 번 진짜라고 여겨진 위조지폐가 있다면, 다양한 지폐를 만들지 않고 이 지폐만 대량 생산해버리는 것을 생각하면 된다.

 

Discriminator 입장에서는 진짜 데이터를 1로, 가짜 데이터를 0으로 받으면 optimal이고,

Generator 입장에서는 가짜 데이터를 1로, Discriminator를 속이기만 하면 된다.

 

문제가 되는 부분은 Generator 이다. 

 

결국 noise를 mnist에 맵핑되는 Generator를 만드는 것이 최종적인 목표인데, 그러면 noise와 mnist의 분포를 일치시키는 과정이 Generator를 학습시키는 과정이라고 볼 수 있다. 0~9까지 균일하게 생성해야 잘 만들어진 Generator라고 할 수 있다. 그런데, 위에서 말했듯이 Generator 입장에서는 Discriminator를 속이기만 하면 된다. 

 

그래서 말도 안되는 숫자들을 생성한다. 근데 Discriminator 속는다. 현상이 바로 Mode collapse이다.


Ref

GAN : https://tyami.github.io/deep%20learning/GAN-1-theory-GAN-DCGAN/

 

.