833 字
4 分钟
重参数化技巧
2024-10-11

https://kexue.fm/archives/6705

引入#

对于以下形式的目标损失函数

Lθ=Ep(z)[fθ(z)]L_{\theta}=\mathbb{E}_{p(z)}[f_{\theta}(z)]

为了最小化这个损失,我们需要从 p 中采样,但是直接采样丢失了θ\theta的信息,从而无法更新θ\theta. 在连续问题或者 z 的取值空间很大的离散问题中,我们不可能遍历所有的 z,因此需要采样.如果 z 的分布与我们需要的梯度的参数无关,则

θEpθ(z)[fθ(z)]=θ[zpθ(z)fθ(z)dz]=zp(z)[θfθ(z)]dz=Ep(z)[θfθ(z)]\begin{align} \nabla_{\theta}\mathbb{E}_{p_{\theta}(z)}[f_{\theta}(z)]=\nabla_{\theta}\left[ \int_{z}p_{\theta}(z)f_{\theta}(z)dz \right] \\ =\int_{z}p(z)[\nabla_{\theta}f_{\theta}(z)]dz \\ =\mathbb{E}_{p(z)}[\nabla_{\theta}f_{\theta}(z)] \end{align}

但是如果 pz 和θ\theta有关,那么变成

θEpθ(z)[fθ(z)]=θ[zpθ(z)fθ(z)dz]=zθ[pθ(z)fθ(z)dz]=zfθ(z)θpθ(z)dz+zpθ(z)θfθ(z)dz\begin{align} \nabla_{\theta}\mathbb{E}_{p_{\theta}(z)}[f_{\theta}(z)]=\nabla_{\theta}\left[ \int_{z}p_{\theta}(z)f_{\theta}(z)dz \right] \\ =\int_{z}\nabla_{\theta}[p_{\theta}(z)f_{\theta}(z)dz] \\ =\int_{z}f_{\theta}(z)\nabla_{\theta}p_{\theta}(z)dz+\int_{z}p_{\theta}(z)\nabla_{\theta}f_{\theta}(z)dz \end{align}

由于我们需要计算分布 p 的梯度,第一项无法变成期望的形式,也无法采样

Reparameterization#

考虑连续情况

Lθ=zpθ(z)f(z)dzL_{\theta}=\int_{z}p_{\theta}(z)f(z)dz

我们需要采样的同时保留梯度,为此我们考虑从无参分布 q 中采样,然后通过某种变换生成 z:

ϵq(ϵ)z=gθ(ϵ)\begin{align} \epsilon \sim q(\epsilon) \\ z=g_{\theta}(\epsilon) \end{align}

此时式子变成:

Lθ=Eϵq(ϵ)[f(gθ(ϵ))]L_{\theta}=\mathbb{E}_{\epsilon \sim q(\epsilon)}[f(g_{\theta}(\epsilon))]

此时我们把随机采样和梯度传播解耦了,可以直接反向传播 loss.如何理解呢?假设有一个 p,我们从中采样了一个 2,我们完全看不出 2 和参数有什么关系,但如果我们从标准正态分布中采样一个 0.2,然后计算0.2σ+μ0.2\sigma+\mu,我们就知道和θ\theta的关系了.

Gumbel-softmax#

现在是离散情况

Lθ=pθ(y)f(y)L_{\theta}=\sum p_{\theta}(y)f(y)

也就是此时pθ(y)p_{\theta}(y)是一个 k 分类模型,

py=softmax(o1,o2,o3)py=softmax(o_{1},o_{2},\dots o_{3})

理论上我们可以直接求 loss,如果空间巨大,比如 y 是一个一百维度向量,每个维度 2 类取值,我们还是得采样.和上面一样,我们考虑如何分离随机采样.如果能够采样若干个点就能得到(6)的有效估计,并且还不损失梯度信息,那自然是最好了.为此我们引入了 Gumbel Max,他提供了一种从类别分布中采样的方法.

argmax(logpilog(logϵi))i=1k,ϵiU[0,1]argmax(\log p_{i}-\log(-\log\epsilon_{i}))^k_{i=1},\epsilon_{i} \sim U[0,1]

也就是先算出各个概率的对数,然后从均匀分布[0,1]中采样 k 个数,然后加起来把最大值的类别抽出来就好了.这样的过程精确等价于依据概率抽样一个类别.也就是输出 i 的概率就是 pi.现在的随机性已经到均匀分布上了,不带参数,所以这是一个重参数化过程. 但是我们不希望丢失梯度信息,这个函数做不到,因为 arggmax 不可导.因此我们需要再一部近似.首先注意再神经网络中,处理离散输入的方法是转为 one hot,因此argmax 本质上也是 onehot(argmax),表示模型最终选了哪个类别,然后我们对这个求光滑近似,就是 softmax.因此我们就得到了他的光滑版本.Gumbel softmax

softmax(logpilog(logϵi)/τ)i=1k,ϵiU[0,1]softmax(\log p_{i}-\log(-\log\epsilon_{i})/\tau)^k_{i=1},\epsilon_{i} \sim U[0,1]

其中 tau 是退火参数,越小越接近 onehot,但是梯度消失越严重

证明 Gumbel max 对于logp1log(logϵ1)\log p_{1}-\log(-\log\epsilon_{1}),他大于其他所有的,所以 我们化简logp1log(logϵ1)>logpilog(logϵi)\log p_{1}-\log(-\log\epsilon_{1})>\log p_{i}-\log(-\log\epsilon_{i}) 得到ϵi<ϵip2/p11\epsilon_{i}<\epsilon_{i}^{p_{2}/p_{1}}\leq 1,而由于ϵiU[0,1]\epsilon_{i}\sim U[0,1],所以成立的概率为ϵp2/p1\epsilon^{p_{2}/p_{1}},如果需要同时满足,那么就可以得到概率为ϵ11/p11\epsilon_{1}^{1/p_{1}-1}.对所有的ϵ1\epsilon_{1}求平均就是01ϵ11/p11dϵ1=p1\int^{1}_{0}\epsilon_{1}^{1/p_{1}-1}d\epsilon_{1}=p_{1},这就是类别 1 出现的概率.至此我们完成了采样过程的证明

重参数化技巧
https://fuwari.vercel.app/posts/数学/重参数化技巧/
作者
FlyingWhite
发布于
2024-10-11
许可协议
CC BY-NC-SA 4.0