AI智能
改变未来

Policy gradient 解决Sequence generation使用GAN时梯度无法更新的问题


GAN

z∈RT×dz\\in \\mathbb{R}^{T\\times d}z∈RT×d
outputg=Gθ(z)∈RT×V,whereVisthesizeofvocabulary.output_g=G_{\\theta}(z)\\in\\mathbb{R}^{T\\times V},\\ \\text{where V is the size of vocabulary.}outputg​=Gθ​(z)∈RT×V,whereVisthesizeofvocabulary.
Y=arg max⁡(outputg)∈RT×1→thisoperationwillpreventgeneratorfrombeingupdated.Y=\\argmax(output_{g})\\in \\mathbb{R}^{T\\times1}\\rightarrow\\textbf{\\text{this operation will prevent generator from being updated.}}Y=argmax(outputg​)∈RT×1→thisoperationwillpreventgeneratorfrombeingupdated.
mappingYtodensevectorasDiscriminatorinputbyusingembedding:\\text{mapping Y to dense vector as Discriminator input by using embedding:}mappingYtodensevectorasDiscriminatorinputbyusingembedding:
inputd=Embedding(Y)∈RT×dminput_d=Embedding(Y)\\in\\mathbb{R}^{T\\times d_m}inputd​=Embedding(Y)∈RT×dm​
outputd=Dθ(inputd)∈Routput_d=D_{\\theta}(input_d)\\in\\mathbb{R}outputd​=Dθ​(inputd​)∈R
L=Ex∼pdata(x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))](1)\\mathcal{L}=\\mathbb{E}_{x\\sim p_{data}(x)}[\\log D(x)]+\\mathbb{E}_{z\\sim p_z(z)}[\\log(1-D(G(z)))]\\ (1)L=Ex∼pdata​(x)​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))](1)

GAN在计算机视觉特别是图像生成方面应用很广泛,但是在文本生成领域应用有比较大的困难。主要是因为GAN是用于生成真实连续的数据(例如图像), 而文本生成是生成离散的数据(对应于词典中的字符)。具体的说,在文本生成中,generator中有argmax操作,该操作是不可导的,在反向传播时,梯度更新会在该操作处停止,从而使Generator无法更新。

Policy Gradient

从以上分析可以得出,Generator无法更新主要是存在不可求导操作(arg max⁡\\argmaxargmax)引起的。解决这个问题可以从两个方面思考,一个是用一个可导函数(神经网络)逼近arg max⁡\\argmaxargmax操作(这个是我自己猜想的,并没有找到参考文献,不一定可行),另一个是在更新Generator时不使用该操作。policy gradient就是使用的第二种方法,这里的policy指的就是Generator。
Policy gradient是强化学习(Reinforcement learning, RL)中的一种说法,具体的可以参考网上强化学习的资料。RL主要有几个要素agent、environment、action、reward、state。将RL运用到文本生成领域,可以将agent看成generator, environment看成Discriminator,action为将要生成的字符(token),reward为Discriminator给出的打分(生成的句子被判断成为真实句子的概率),state为已经生成的token。
Discriminator的更新与上述方法并无区别,Generator更新主要区别在于loss并非从discriminator端传入(式1),而是将V(s0)V(s_0)V(s0​)(状态s0s_0s0​的价值函数)作为loss,目标则是最大化价值函数。具体式子如下:
J(θ)=E[RT∣s0,θ]=∑y1∈YGθ(y1∣s0)QGθDϕ(s0,y1)J(\\theta)=\\mathbb{E}[R_T|s_0,\\theta]=\\sum_{y_1\\in\\mathcal{Y}}G_{\\theta}(y_1|s_0)Q_{G_{\\theta}}^{D_{\\phi}}(s_0,y_1)J(θ)=E[RT​∣s0​,θ]=y1​∈Y∑​Gθ​(y1​∣s0​)QGθ​Dϕ​​(s0​,y1​)
该式出自论文seqGAN。按照我的理解,s0s_0s0​为sequence开始标识,是一个特殊字符,例如<BOS>。y1y_1y1​为即将生成的下一个token。Y\\mathcal{Y}Y为字典。Gθ(y1∣s0)G_{\\theta}(y_1|s_0)Gθ​(y1​∣s0​)表示当前状态为s0s_0s0​,在策略GθG_{\\theta}Gθ​下,下一个action为y1y_1y1​的概率。QGθDϕ(s0,y1)Q_{G_{\\theta}}^{D_{\\phi}}(s_0,y_1)QGθ​Dϕ​​(s0​,y1​)为action-value。文章采用Monte Carlo采样方法计算该值:
QGθDϕ(s=Y1:t−1,a=yt)=1N∑n=1NDϕ(Y1:Tn),Y1:Tn∈MCGβ(Y1:T;N)Q_{G_{\\theta}}^{D_{\\phi}}(s=Y_{1:t-1},a=y_t)=\\dfrac{1}{N}\\sum_{n=1}^ND_{\\phi}(Y_{1:T}^n), Y_{1:T}^n \\in MC^{G_{\\beta}}(Y_{1:T};N)QGθ​Dϕ​​(s=Y1:t−1​,a=yt​)=N1​n=1∑N​Dϕ​(Y1:Tn​),Y1:Tn​∈MCGβ​(Y1:T​;N)
上式计算的是state为Y1:t−1Y_{1:t-1}Y1:t−1​时,action为yty_tyt​的return。这样做的好处在于,不仅可以获得整个序列的reward,还可以获得中间任意位置生成token的reward。也就是说,不但考虑了长期受益,也考虑了短期收益。
我们说过policy gradient可以有效的解决梯度更新无法传递到generator上的问题,观察上式,跟新generator的目标函数J(θ)J({\\theta})J(θ)并没有采用任何不可求导操作,整个式子可以看成是state-action pair(s,y)(s,y)(s,y)能获得的Reward的期望(GθG_{\\theta}Gθ​返回的是概率值)。
上述讲的很粗糙,要真正的理解其中细节,需要仔细的看看论文,这篇论文花了我很长时间。先是恶补了一下强化学习的相关资料,然后再啃的论文,现在也不敢说百分百读懂了。欢迎一起交流!

赞(0) 打赏
未经允许不得转载:爱站程序员基地 » Policy gradient 解决Sequence generation使用GAN时梯度无法更新的问题