Wasserstein-1(WGAN)

DCGAN에서 크로스 엔트로피 손실 함수를 사용했을 떄, 잘 동작하지 않는다.
훈련 성능을 높이기 위해서 진짜와 가짜 이미지 분포 사이의 바서슈타인-1(Wasserstein-1)거리(EM 거리, Earth Mover)
기반의 손실 함수를 이용한다.
두 분포 P(x), Q(x) 사이의 거리(dissimilarity)를 측정할 수 있는 방법
TV(Total Variation)
sup(S)는 S의 모든 원소보다 큰 가장 작은 값을 의미한다: S의 최소상계(least upper bound)
EM distance(Earth mover’s distance)
TV와 반대로 inf(S)는 S의 모든 원소보다 작은 가장 큰 값을 의미한다: S의 최대 하계(greatest lower bound)

- TV거리는 각 포인트에서 두 분산 사이의 가장 큰 차이를 측정
- EM거리는 한 분포에서 다른 분포로 변화할 때, 필요한 최소의 작업량으로 해석할 수 있다.
   EM거리는 P와 Q의 결합 확률 분포 집합에서 계산된다.
- KL(콜백-라이블러)와 JS(젠슨-섀넌) 발산은 정보 이론 분야에서 유래되었다.
   KL 발산은 대칭은 아니다(KL(P || Q) != KL(Q || P)) JS 발산은 대칭
ex)
KL 발산 & 크로스 엔트로피
KL 발산 KL(P || Q)는 참조 분포 Q에 대한 분포 P의 상대적인 엔트로피를 측정한다.
이산이나 연속분포에서 풀어 쓴 고식을 기반으로 KL 발산을 P, Q사이의 크로스 엔트로피에서 P 자체의 엔트로피를 뺀 것으로
볼 수 있다.
KL(P || Q)=H(P, Q)-H(P)
원본 GAN에서 사용한 손실 함수는 진짜와 가짜 샘플 사이의 JS 발산을 최소화하는 것이다.
JS(P|Q)=-(H(P)+H(Q))/2
하지만 JS는 GAN 모델을 훈련하는데 문제가 있다.
http://proceedings.mlr.press/v70/arjovsky17a/arjovsky17a.pdf
EM distance
한 직선은 x=0 다른 직선은 x=theta에 있고 x축을 따라 이동한다고 하면,(theta>0)
KL, TV, JS 거리는 각각 KL(P||Q)=infinite. TV(P, Q)=1, JS(P, Q)=log2로
P와 Q가 서로 비슷해지도록 theta에 대하여 미분할 수가 없다.
EM 거리는 EM(P, Q)=abs(theta)로 theta에 대한 그레이디언트가 존재하고 Q를 P쪽으로 이동시킬 수 있다.

EM거리는 계산하는 것은 그 자체가 최적화 문제로 계산이 매우 어렵다.
EM거리 계산을 칸토로비치-루빈스타인 쌍대성(Kantorovich-Rubinstein duality)이론을 사용하여 단순화하여 사용한다.
상한은 1-립시츠(1-Lipschitz) 연속함수에 대해 적용
분포 사이 거리 측정, 1-립시츠등 참고
머신러닝교과서with파이썬,사이키선,텐서플로_개정3판pg.755-758
GAN에서 EM 사용
심층 신경망을 이용하면, 어떤 함수도 근사할 수 있다.
신경망 모델을 훈련하여 바서슈타인 거리 함수를 근사할 수 있다.

기본 GAN에서는 판별자를 분류기 형태로 사용한다.
WGAN에서는 판별자를 바꾸어 확률 점수 대신에 스칼라 점수를 반화하는 비평자(critic)로 바꿀 수 있다.
비평자 함수의 1-립시츠 성질을 훈련하는 동안 유지해야 한다.
—> 가중치를 작은 범위 예를 들어 [-0.01, 0.01] 사이로 클리핑(clipping)한다.
그레이디언트 페널티(Gradient Penalty, GP)
WGAN 논문에서는 판별자(비평자)의 1-립시츠 성질을 위해 가중치 클리핑을 제안했다.
하지만 한 논문에서 가중치 클리핑이 그레이디언트 폭주와 소실로 이끌고 가중치 클리핑으로 인해 충분한 성능을 내지 못함을 밝혔다.
(간단한 함수만 학습할 수 있음)

그레이디언트 페널티(Gradient Penalty, GP)
WGAN-GP
1. 한 배치에서 진짜와 가짜 샘플의 각 쌍에 대해 균등 분포에서 랜덤한 수를 샘플링한다.
2. 진짜와 가짜 샘플 사이를 보간한다.->보간된 샘플의 배치가 만들어진다.
3. 보간된 저체 샘플에 대해 판별자 출력을 계산한다.
4. 각 보간된 샘플에 대해 비평자 출력의 그레이디언트를 계산한다.
5. GP를 계산한다.

판별자의 총 손실은 진짜 손실 + 가짜 손실에 lambda가 곱해진 그레이디언트 페널티의 합이다.
(lambda는 튜닝 가능한 하이퍼파라미터)