文章目录
- 一 RNN概述
- 1.1 BP算法,CNN之后, 为什么还有RNN?
- 1.2 什么是RNN?
- 1.3 RNN的主要应用领域有哪些呢?
- 2.1 RNN模型结构
- 2.2 RNN的反向传播
- 3.1 LSTM算法(Long Short Term Memory, 长短期记忆网络 )
- 3.2 GRU算法
一 RNN概述
前面我们叙述了BP算法, CNN算法, 那么为什么还会有RNN呢?? 什么是RNN, 它到底有什么不同之处? RNN的主要应用领域有哪些呢?这些都是要讨论的问题.
1.1 BP算法,CNN之后, 为什么还有RNN?
细想BP算法,CNN(卷积神经网络)我们会发现, 他们的输出都是只考虑前一个输入的影响而不考虑其它时刻输入的影响, 比如简单的猫,狗,手写数字等单个物体的识别具有较好的效果. 但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了.因此, RNN就应运而生了.
1.2 什么是RNN?
RNN是一种特殊的神经网络结构, 它是根据\”人的认知是基于过往的经验和记忆\”这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种’记忆’功能.
RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出。
1.3 RNN的主要应用领域有哪些呢?
RNN的应用领域有很多, 可以说只要考虑时间先后顺序的问题都可以使用RNN来解决.这里主要说一下几个常见的应用领域:
① 自然语言处理(NLP): 主要有视频处理, 文本生成, 语言模型, 图像处理
② 机器翻译, 机器写小说
③ 语音识别
④ 图像描述生成
⑤ 文本相似度计算
⑥ 音乐推荐、网易考拉商品推荐、Youtube视频推荐等新的应用领域.
二 RNN(循环神经网络)
2.1 RNN模型结构
前面我们说了RNN具有时间\”记忆\”的功能, 那么它是怎么实现所谓的\”记忆\”的呢?
如图1所示, 我们可以看到RNN层级结构较之于CNN来说比较简单, 它主要有输入层,Hidden Layer, 输出层组成.
并且会发现在Hidden Layer 有一个箭头表示数据的循环更新, 这个就是实现时间记忆功能的方法.
如果到这里你还是没有搞懂RNN到底是什么意思,那么请继续往下看!
如图2所示为Hidden Layer的层级展开图. t-1, t, t+1表示时间序列. X表示输入的样本. St表示样本在时间t处的的记忆,St = f(WSt-1 +UXt). W表示输入的权重, U表示此刻输入的样本的权重, V表示输出的样本权重.
在t =1时刻, 一般初始化输入S0=0, 随机初始化W,U,V, 进行下面的公式计算:
其中,f和g均为激活函数. 其中f可以是tanh,relu,sigmoid等激活函数,g通常是softmax也可以是其他。
时间就向前推进,此时的状态s1作为时刻1的记忆状态将参与下一个时刻的预测活动,也就是:
以此类推, 可以得到最终的输出值为:
注意: 1. 这里的W,U,V在每个时刻都是相等的(权重共享).
2. 隐藏状态可以理解为: S=f(现有的输入+过去记忆总结)
2.2 RNN的反向传播
前面我们介绍了RNN的前向传播的方式, 那么RNN的权重参数W,U,V都是怎么更新的呢?
每一次的输出值Ot都会产生一个误差值Et, 则总的误差可以表示为:.
则损失函数可以使用交叉熵损失函数也可以使用平方误差损失函数.
由于每一步的输出不仅仅依赖当前步的网络,并且还需要前若干步网络的状态,那么这种BP改版的算法叫做Backpropagation Through Time(BPTT) , 也就是将输出端的误差值反向传递,运用梯度下降法进行更新.(不熟悉BP的可以参考这里)
也就是要求参数的梯度:
首先我们求解W的更新方法, 由前面的W的更新可以看出它是每个时刻的偏差的偏导数之和.
在这里我们以 t = 3时刻为例, 根据链式求导法则可以得到t = 3时刻的偏导数为:
此时, 根据公式我们会发现, S3除了和W有关之外, 还和前一时刻S2有关.
对于S3直接展开得到下面的式子:
对于S2直接展开得到下面的式子:
对于S1直接展开得到下面的式子:
将上述三个式子合并得到:
这样就得到了公式:
这里要说明的是:
表示的是S3对W直接求导, 不考虑S2的影响.(也就是例如y = f(x)*g(x)对x求导一样)
其次是对U的更新方法. 由于参数U求解和W求解类似,这里就不在赘述了,最终得到的具体的公式如下:
最后,给出V的更新公式(V只和输出O有关):
三 RNN的一些改进算法
前面我们介绍了RNN的算法, 它处理时间序列的问题的效果很好, 但是仍然存在着一些问题, 其中较为严重的是容易出现梯度消失或者梯度爆炸的问题(BP算法和长时间依赖造成的). 注意: 这里的梯度消失和BP的不一样,这里主要指由于时间过长而造成记忆值较小的现象.
因此, 就出现了一系列的改进的算法, 这里介绍主要的两种算法: LSTM 和 GRU.
LSTM 和 GRU对于梯度消失或者梯度爆炸的问题处理方法主要是:
对于梯度消失: 由于它们都有特殊的方式存储”记忆”,那么以前梯度比较大的”记忆”不会像简单的RNN一样马上被抹除,因此可以一定程度上克服梯度消失问题。
对于梯度爆炸:用来克服梯度爆炸的问题就是gradient clipping,也就是当你计算的梯度超过阈值c或者小于阈值-c的时候,便把此时的梯度设置成c或-c。
3.1 LSTM算法(Long Short Term Memory, 长短期记忆网络 )
重要的目前使用最多的时间序列算法
图为LSTM算法的结构图.
和RNN不同的是: RNN中,就是个简单的线性求和的过程. 而LSTM可以通过“门”结构来去除或者增加“细胞状态”的信息,实现了对重要内容的保留和对不重要内容的去除. 通过Sigmoid层输出一个0到1之间的概率值,描述每个部分有多少量可以通过,0表示“不允许任务变量通过”,1表示“运行所有变量通过 ”.
用于遗忘的门叫做\”遗忘门\”, 用于信息增加的叫做\”信息增加门\”,最后是用于输出的\”输出门\”. 这里就不展开介绍了.
此外,LSTM算法的还有一些变种.
如图4所示, 它增加“peephole connections”层 , 让门层也接受细胞状态的输入.
如图所示为LSTM的另外一种变种算法.它是通过耦合忘记门和更新输入门(第一个和第二个门);也就是不再单独的考虑忘记什么、增加什么信息,而是一起进行考虑。
3.2 GRU算法
GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.
在这里插入图片描述
其各个部分满足关系式如下:
四 基于Tensorflow的基本操作和总结
使用tensorflow的基本操作如下:
# _*_coding:utf-8_*_import tensorflow as tfimport numpy as np\'\'\'TensorFlow中的RNN的API主要包括以下两个路径:1) tf.nn.rnn_cell(主要定义RNN的几种常见的cell)2) tf.nn(RNN中的辅助操作)\'\'\'# 一 RNN中的cell# 基类(最顶级的父类): tf.nn.rnn_cell.RNNCell()# 最基础的RNN的实现: tf.nn.rnn_cell.BasicRNNCell()# 简单的LSTM cell实现: tf.nn.rnn_cell.BasicLSTMCell()# 最常用的LSTM实现: tf.nn.rnn_cell.LSTMCell()# RGU cell实现: tf.nn.rnn_cell.GRUCell()# 多层RNN结构网络的实现: tf.nn.rnn_cell.MultiRNNCell()# 创建cell# cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)# print(cell.state_size)# print(cell.output_size)# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征# inputs = tf.placeholder(dtype=tf.float32, shape=[4, 64])# 给定RNN的初始状态# s0 = cell.zero_state(4, tf.float32)# print(s0.get_shape())# 对于t=1时刻传入输入和state0,获取结果值# output, s1 = cell.call(inputs, s0)# print(output.get_shape())# print(s1.get_shape())# 定义LSTM celllstm_cell = tf.nn.rnn_cell.LSTMCell(num_units=128)# shape=[4, 64]表示每次输入4个样本, 每个样本有64个特征inputs = tf.placeholder(tf.float32, shape=[4, 48])# 给定初始状态s0 = lstm_cell.zero_state(4, tf.float32)# 对于t=1时刻传入输入和state0,获取结果值output, s1 = lstm_cell.call(inputs, s0)print(output.get_shape())print(s1.h.get_shape())print(s1.c.get_shape())
当然, 你可能会发现使用cell.call()每次只能调用一个得到一个状态, 如有多个状态需要多次重复调用较为麻烦, 那么我们怎么解决的呢? 可以参照后面的基于RNN的手写数字识别和单词预测的实例查找解决方法.