负损失函数: 能量模型视角
能量模型的视角可以很好地连接回KL损失, 我们知道KL损失在优化上等价于交叉熵,
因为数据分布的熵一般不受模型影响. 做一定变换,我们观察一下负损失函数关于参数的梯度,可以发现很有趣的性质
\[\begin{split}\begin{align}
L(m) &= - D_{KL}(p_{data}(x) || q_m(x))\\
&= c+ E_{p_{data}}[\log q_m(x)] \\
q_m(x) &= {1\over Z_m} \exp U_m(x) \\
\log q_m(x) &= - \log Z_m + U_m(x) \\
&= - \log \int_x \exp U_m(x) + U_m(x) \\
{\partial \over \partial m }L(m) &= 0 + E_{p_{data}}[{\partial \over \partial m } \log q_m(x)]\\
&= E_{p_{data}}[{\partial \over \partial m } (- \log \int_x \exp U_m(x) + U_m(x))] \\
&= E_{p_{data}}[ (- \int_x { \exp U_m(x) \over \int_x \exp U_m(x)} {\partial \over \partial m } U_m(x) + {\partial \over \partial m } U_m(x))] \\
&= E_{p_{data}}[ (- \int_x q_m(x) {\partial \over \partial m } U_m(x) + {\partial \over \partial m } U_m(x))] \\
&= E_{p_{data}}[ - E_{q_m(x)} [ {\partial \over \partial m } U_m(x)] + {\partial \over \partial m } U_m(x))]\\
&= - E_{q_m(x)} [ {\partial \over \partial m } U_m(x)] + E_{p_{data}}[ {\partial \over \partial m } U_m(x))] \\
&\neq {\partial \over \partial m } \left( - E_{q_m(x)} [ U_m(x)] + E_{p_{data}}[ U_m(x))] \right ) \\
L(m) &\neq c - E_{q_m(x)} [ U_m(x)] + E_{p_{data}}[ U_m(x))]
\end{align}\end{split}\]
让我们退回到GAN的范式上来,既然 \(q(x)\) 上采样困难,我们就可以对其做KL变分近似
\[\begin{split}\begin{align}
{\partial \over \partial m} L(m) &= - E_{q_m(x)} [ {\partial \over \partial m } U_m(x)] + E_{p_{data}}[ {\partial \over \partial m } U_m(x))] \\
&\approx - E_{q_f(x)} [ {\partial \over \partial m } U_m(x)] + E_{p_{data}}[ {\partial \over \partial m } U_m(x))] \\
L_1(m,f) &= -E_{q_f(x)} [ U_m(x)] + E_{p_{data}}[ U_m(x))] \\
L_2(m,f) &= - D_{KL}(q_f(x)||q_m(x) ) \\
&= E_{q_f(x)}[\log q_m(x)] - E_{q_f(x)}[\log q_f(x)] \\
&= E_{q_f(x)}[U_m(x)-\log Z_m] - E_{q_f(x)}[\log q_f(x)]
\end{align}\end{split}\]
L1和L2的冲突在于其中有一项的正负号是相反的.
按照对抗思想,优化L1时只调整m,优化L2时只调整f. L1在L2尽量大的情况下才有意义.
\[\begin{split}\begin{align}
L_3(m,f) &= E_{p_{data}(x)}[U_m(x)]-E_{q_f(x)}[ U_m(x)] + E_{q_f(x)}[\log q_f(x)] \\
\nabla_m L_1(m,f)&= \nabla_m L_3(m,g) \\
&= \nabla_m \left(-E_{q_f(x)} [ U_m(x)] + E_{p_{data}}[ U_m(x))] \right)\\
\nabla_f L_2(m,f) &= -\nabla_f L_3(m,g) \\
&= \nabla_f \left( E_{q_f(x)}[U_m(x)-\log Z_m] - E_{q_f(x)}[\log q_f(x)]\right)\\
\end{align}\end{split}\]
根据梯度对应,写出minimax形式到这里为止,对于minimax现象的推导就结束了
\[
L = \max_m \min_f E_{p_{data}(x)}[ U_m(x)]-E_{q_f(x)}[U_m(x)] + E_{q_f(x)}[\log q_f(x)]
\]
负损失函数: Vanilla GAN的minmiax自我博弈框架
注:这段推导的构造性不是很强
GAN模仿了一个自我min-max博弈的过程.局部的有一个类似交叉熵的形式
\[\begin{split}\begin{align*}
L &=\min_G \max_D E_{p_{data}(x)}[\log D(x)]
+ E_{q(z)}[\log(1-D(G(z)))] \\
&=\min_G \max_D E_{p_{data}(x)}[\log D(x)]
- E_{q(z)}[\log({1 \over 1-D(G(z))})] \\
\end{align*}
\end{split}\]
\[\begin{split}\begin{align*}
L &=\min_G \max_D E_{p_{data}(x)}[\log D(x)]
+ E_{q(z)}[\log(1-D(G(z)))] \\
&= \min_G \max_D \sum_{x} p_{data}(x) \log D(x) + \sum_{z} q(z) \log(1-D(G(z)))
\end{align*}
\end{split}\]
这里\(D(x)\)定义了一个Bernoulli分布的概率函数. 我们可以记Bernoulli分布 \(g(k=1|x,D)=D(x)\),
得到
\[\begin{split}\begin{align*}
L
&= \min_G \max_D \sum_{x} p_{data}(x) \log g(1|x,D) + \sum_{z} q(z) \log g(0|G(z),D) \\
&= \min_G \max_D CE(p_{data}(x), g(1|x,D) ) + CE( q(z), g(0|G(z),D) )
\end{align*}
\end{split}\]
这个约束是直接施加在条件分布g上的,所以严格意义上是归一化的CE. 考虑一个概率性的解码器 \(h_G(x|z)\),
感觉这里对Bernoulli分布取log是一种下界,暂且不讨论未log的形式. 调整一下求和顺序,加以归一化,
\[\begin{split}\begin{align*}
L
&= \min_G \max_D \sum_{x} p_{data}(x) \log g(1|x,D) + \sum_{z} q(z) \sum_x h_G(x|z) \log g(0|x,D) \\
&= \min_G \max_D \sum_{x} p_{data}(x) \log g(1|x,D) + \sum_x \sum_{z} q(z) h_G(x|z) \log g(0|x,D) \\
&= \min_G \max_D \sum_{x} \left[ p_{data}(x) \log g(1|x,D) + \sum_{z} q(z) h_G(x|z) \log g(0|x,D) \right ] \\
&= \min_G \max_D \sum_{x}
(p_{data}(x) + \sum_{z} q(z) h_G(x|z))
\left[ \begin{aligned}
{p_{data}(x)
\over p_{data}(x) + \sum_{z} q(z) h_G(x|z)
}
\log g(1|x,D) \\+
{\sum_{z} q(z) h_G(x|z)
\over p_{data}(x) + \sum_{z} q(z) h_G(x|z)
}
\log g(0|x,D) \end{aligned} \right ] \\
&= \min_G \max_D \sum_{x}
f(x|D,G,q)
\left[
c(k=1|x,q(z),h_G)
\log g(1|x,D) +
c(k=0|x,q(z),h_G)
\log g(0|x,D) \right ] \\
&= \min_G \max_D \sum_{x}
f(x|D,G,q) \sum_{k\in\{0,1\}}
c(k|x,q(z),h_G) \log g(k|x,D) \\
\end{align*}
\end{split}\]
\[\begin{split}
\left\{ \begin{aligned}
L &= \min_G \max_D \left[ \sum_{x} f(x|D,G,q)
CE_k\left[c(k|x,q(z),h_G), g(k|x,D) \right] \right] \\
f(x|D,G,q)&= p_{data}(x) + \sum_{z} q(z) h_G(x|z) \\
c(k=1|x,q(z),h_G) &= {p_{data}(x)
\over p_{data}(x) + \sum_{z} q(z) h_G(x|z)
} \\
g(k|x,D) &= D(x)
\end{aligned}
\right.\end{split}\]
注意以上使用的都是恒等变换,没有任何近似的成分.原式是从AAE的论文里quote来的.
一般来讲 \(g_D(k|x)\) 是已经归一化了的, 这个形式就成为了c分布和\(g\)分布之间的交叉熵,其中c描述的是给定x后,样本来自数据还是生成器的概率. \(g\)表示,给定x和判别器D后,样本被贴标签的概率. 我们知道,交叉熵是负KL加上熵,因此这个损失要求c分布在尽量靠近g分布的同时要保留尽量大的熵.
这个玩意要直接优化的话,得在\(x\)上面求和,按照概率 \(f\),进行采样,也就是要按照混合分布进行采样, 对于不在样本里的 \(p_{data}(x)\) 概率,是否应该考虑一下平滑呢?有了解的小伙伴可以讨论一下,总不能直接设置成0吧是不是太粗暴了一点? 对于在样本里的图片,需要近似来自G的生成概率. 然后还要考虑看看DG之间的博弈问题.
还有一个问题就是minmax应该如何去优化? 可能要从优化论里找一些工具.OGAD
似乎有一些好一点的性质?但是这里的损失函数是全局的,并没有一个显然的方法
把算符扔到求和内部去
不得不说GAN的损失函数确实是很有意思的,巧妙地运用Bernoulli分布
构造了一个有意义的优化目标.
草稿
对于负相,其中最后几步即便运用了log技巧,我们也没法把第一项的微分号移到外面
\[\begin{split}\begin{align}
{\partial \over \partial m} L(m) &= - E_{q_m(x)} [ {\partial \over \partial m } U_m(x)] + E_{p_{data}}[ {\partial \over \partial m } U_m(x))] \\
&= - \int_x q_m(x) {\partial \over \partial m } U_m(x) + f_m \\
&= - [ \int_x q_m(x) {\partial \over \partial m } U_m(x) + \int_x U_m(x) {\partial \over \partial m } q_m(x)
- \int_x U_m(x) {\partial \over \partial m } q_m(x)
]+ f_m \\
&= - [ \int_x {\partial \over \partial m } (q_m(x) U_m(x))- \int_x U_m(x) {\partial \over \partial m } q_m(x)
]+ f_m \\
&= - [ {\partial \over \partial m } \int_x (q_m(x) U_m(x))- \int_x U_m(x) q_m(x) {\partial \over \partial m } \log q_m(x)
]+ f_m \\
&= - {\partial \over \partial m } E_{ q_m(x)}[U_m(x)] -E_{q_m(x)}[ U_m(x) {\partial \over \partial m } \log q_m(x) ] + {\partial \over \partial m } E_{p_{data}}[ U_m(x))]
\end{align}\end{split}\]
这把KL的梯度转化成了最大化样本能量并且最小化变分分布上的能量.最小化的这个要求,
来自于能量函数的归一化. 能量函数 \(U_m(x)\) 一般对应GAN里面的判别器网络,只要输出一个实数就够用了.如果考虑 \(U_m(x)\) 关于x的梯度,就可以联系到梯度对齐(SM, 见CATSMILE-1025)的思路上. 注意我们呢可以把上式变换成一个积分形式. 我们可以看到,对于 \(p_{data}(x)>q_m(x)\) 的地方, \(U_m(x)\) 应当尽量大,反之,则应当尽量小. 这个形式有一点类似于Wasserstein散度的对偶形式,但是对偶要求在 \(U_m\) 上取上界,暂且先不考虑微分号转换造成的额外项目
\[\begin{split}\begin{align}
L(m) &= c - E_{q_m(x)} [ U_m(x)] + E_{p_{data}}[ U_m(x))] \\
&= c - \int_x q_m(x) U_m(x) + \int_x {p_{data}(x)} U_m(x)) \\
&= \int_x (p_{data}(x) - q_m(x)) U_m(x)
\end{align}\end{split}\]
尝试分解
回到期望之差的形式,难点在于 \(q_m(x)\) 上的采样很困难,于是我们自然地想到能否
用一个变分分布来对这个分布进行采样,考虑加入一个生成器分解 \(q_m(x) = \sum_z q_w(z) q_r(x|z)\) ,其中 \(q_w(z)\) 一般为简单的高斯先验,这样就会导致一个合理的能量定义,
但是头疼的是这个能量函数就又不好计算了. 不管怎么样,我们把能量函数看成是生成
模型的对数后,可以确保归一化条件 \(Z_m=1\) .此时求能量的梯度会需要对后验分布进行采样,且主要对样本上的梯度比较重要
\[\begin{split}\begin{align}
U_m(x) &=\log \sum_z q_w(z) q_r(x|z) \\
{\partial \over \partial m} U_m(x)
&= \int_z q(z|x) {\partial \over \partial m} \log q_w(z) q_r(x|z) \\
&= E_{q(z|x)}[ {\partial \over \partial m} \log q_{wr}(x,z)] \\
{\partial \over \partial m} L(m) &= - E_{q_m(x)} [ E_{q(z|x)}[{\partial \over \partial m } \log q_{wr}(x,z)]] + E_{p_{data}}[ E_{q(z|x)}[{\partial \over \partial m } \log q_{wr}(x,z)]] \\
&= -E_{q_{w}(z) q_r(x|z)}[\log q_{wr}(x,z) ] + E_{p_{data}(x) q_{wr}(z|x)}[\log q_{wr}(x,z)]
\end{align}\end{split}\]
我们发现第一项是很好采样估算的,但是第二项由于编码分布 \(q_{wr}(z|x)\) 涉及贝叶斯求逆,并不是很好求出. 推了这么多,我们回顾一下和原始KL形式的关系.
\[\begin{split}\begin{align}
{\partial \over \partial m} L(m) &= {\partial\over\partial m}E_{p_{data(x)}}[\log q_m(x)] \\
&= E_{p_{data(x)}}[ {\partial\over\partial m} U_m(x)] \\
&= E_{p_{data(x)}}[ E_{q(z|x)}[ {\partial \over \partial m} \log q_{wr}(x,z)]]
\end{align}\end{split}\]
这么一对比,我们发现,第一项似乎退化成零了,这是由概率归一化所保证的. 那么这样的情况下,可能就要用GAE方法去求解模型了
\[\begin{split}\begin{align}
-E_{q_{w}(z) q_r(x|z)}[ {\partial \over \partial m} \log q_{wr}(x,z) ] &= -\int_{x,z} q(x,z) {\partial \over \partial m} \log q(x,z) \\
&= -\int_{x,z} {q(x,z)\over q(x,z)} {\partial \over \partial m} q(x,z) \\
&= -\int_{x,z} {\partial \over \partial m} q(x,z) \\
&= -{\partial \over \partial m} \int_{x,z} q(x,z) \\
&= -{\partial \over \partial m} 1 \\
&= 0
\end{align}\end{split}\]
\(L_3\) 的所有梯度
最后再附上另外几个梯度, 我们居然还是绕不过求后验分布 \(q_{wr}(z|x)\) .真是越发地复杂了. 关键呢,
这个对于生成器f的导数,由于是狄拉克分布,还不是很好求. 一个简单的办法是通过一个高斯加噪扩展狄拉克分布, 后续也可以 \(q_w(z)\) 上注入一点噪音,来设法近似梯度 \(\nabla_f \log q_f(x)\)
\[\begin{split}\begin{align}
{\partial \over \partial m} L_1 &= - E_{q_f(x)} [ {\partial \over \partial m} U_m(x)] + E_{p_{data}}[ {\partial \over \partial m} U_m(x))] \\
{\partial \over \partial f} L_1&= - {\partial \over \partial f} \int_x q_f(x) U_m(x) \\&= -E_{q_f(x)}[U_m(x){\partial \over \partial f} \log q_f(x)] \\
{\partial \over \partial m} L_2 &= E_{q_f(x)} [ {\partial \over \partial m} U_m(x)] \\
{\partial \over \partial f} L_2 &= E_{q_f(x)}[U_m(x){\partial \over \partial f} \log q_f(x)] - E_{q_f(x)}[(1+\log q_f(x)){\partial\over\partial f} \log q_f(x)]\\
&= E_{q_f(x)}[(U_m(x) - \log q_f(x) - 1){\partial \over \partial f} \log q_f(x)] \\
&= E_{q_f(x)}[(U_m(x) - \log q_f(x) - 1){\partial \over \partial f} \log \sum_z q_w(z) q_r(x|z)] \\
&= E_{q_f(x)}[(U_m(x) - \log q_f(x) - 1)E_{q_{f}(z|x)}[{\partial \over \partial f} \log q_w(z) q_r(x|z)]] \\
\end{align}\end{split}\]
如果我把原始L里面的采样换成,按照1:1的比例组合这两个损失函数,会发现
m的更新不再需要关注自身的采样了,只要在 \(f\) 上采样并且更新熵就可以了.但这样似乎和minmax思想是冲突的,而且m和f之间不再有任何的耦合关系. 这说明这个约等于号可能没啥保障.
\[\begin{split}\begin{align}
L_4(m,f) &= L_1(m,f) + L_2(m,f) \\
&= E_{p_{data}}[ U_m(x))] - E_{q_f(x)}[\log q_f(x)]
\end{align}\end{split}\]
即便考虑反向KL,我们发现这样也没法绕过在 \(q_m(x)\) 上的采样.
\[\begin{split}\begin{align}
L_5(m,f) &= -D_{KL}(q_m(x)||q_f(x)) \\
&= E_{q_m(x)}[\log q_f(x)] - E_{q_m(x)}[\log q_m(x)]
\end{align}\end{split}\]