常见Cell和函数
- tf.nn.rnn_cell.BasicRNNCell: 最基本的RNN cell.
- tf.nn.rnn_cell.LSTMCell: LSTM cell
- tf.nn.rnn_cell.LSTMStateTuple: tupled LSTM cell
- tf.nn.rnn_cell.MultiRNNCell: 多层Cell
- tf.nn.rnn_cell.DropoutCellWrapper: 给Cell加上dropout
- tf.nn.dynamic_rnn: 动态rnn
- tf.nn.static_rnn: 静态rnn
BasicRNNCell
API
1 | __init__( |
示例
1 | myrnn = rnn.BasicRNNCell(rnn_size,activation=tf.nn.relu) |
其他
TF 2.0将会弃用,等价于tf.keras.layers.SimpleRNNCell()
LSTMCell
API
1 | __init__( |
示例
1 | lstm = rnn.BasicLSTMCell(lstm_size, forget_bias=1, state_is_tuple=True) |
其他
TF 2.0将会弃用,等价于tf.keras.layers.LSTMCell
LSTMStateTuple
和LSTMCell一样,只不过state用的是tuple。
其他
TF 2.0将会弃用,等价于tf.keras.layers.LSTMCell
MultiRNNCell
这个类可以实现多层RNN。
API
1 | __init__( |
示例
代码1
1 | num_units = [128, 64] |
代码2
1 | lstm_cell = rnn.BasicLSTMCell(lstm_size, forget_bias=1, state_is_tuple=True) |
其他
TF 2.0将会弃用,等价于tf.keras.layers.StackedRNNCells
DropoutCellWrapper
API
1 | __init__( |
示例
1 | lstm_cell = rnn.BasicLSTMCell(lstm_size, forget_bias=1, state_is_tuple=True) |
其他
static_rnn
API
1 | tf.nn.static_rnn( |
示例
1 | myrnn = tf.nn.rnn_cell.BasicRNNCell(rnn_size,activation=tf.nn.relu) |
dynamic rnn
API
1 | tf.nn.dynamic_rnn( |
示例
1 | # 例子1.创建一个BasicRNNCell |
static_rnn vs dynamic_rnn
tf.keras.layers.RNN(cell)
在tensorflow 2.0中,上述两个API都会被弃用,使用新的keras.layers.RNN(cell)
tf.nn.rnn_cell
该模块提供了许多RNN cell类和rnn函数。
类
- class BasicRNNCell: 最基本的RNN cell.
- class BasicLSTMCell: 弃用了,使用tf.nn.rnn_cell.LSTMCell代替,就是下面那个
- class LSTMCell: LSTM cell
- class LSTMStateTuple: tupled LSTM cell
- class GRUCell: GRU cell (引用文献 http://arxiv.org/abs/1406.1078).
- class RNNCell: 表示一个RNN cell的抽象对象
- class MultiRNNCell: 由很多个简单cells顺序组合成的RNN cell
- class DeviceWrapper: 保证一个RNNCell在一个特定的device运行的op.
- class DropoutWrapper: 添加droput到给定cell的的inputs和outputs的op.
- class ResidualWrapper: 确保cell的输入被添加到输出的RNNCell warpper。
函数
- static_rnn(…) # 未来将被弃用,和tf.contrib.rnn.static_rnn是一样的。
- dynamic_rnn(…) # 未来将被弃用
- static_bidirectional_rnn(…) # 未来将被弃用
- bidirectional_dynamic_rnn(…) # 未来将被弃用
- raw_rnn(…)
tf.contrib.rnn
该模块提供了RNN和Attention RNN的类和函数op。
类
- class RNNCell: # 抽象类,所有Cell都要继承该类。所有的Warpper都要直接继承该Cell。
- class LayerRNNCell: # 所有的下列定义的Cell都要使用继承该Cell,该Cell继承RNNCell,所以所有下列Cell都间接继承RNNCell。
- class BasicRNNCell:
- class BasicLSTMCell: # 将被弃用,使用下面的LSTMCell。
- class LSTMCell:
- class LSTMStateTuple:
- class GRUCell:
- class MultiRNNCell:
- class ConvLSTMCell:
- class GLSTMCell:
- class Conv1DLSTMCell:
- class Conv2DLSTMCell:
- class Conv3DLSTMCell:
- class BidirectionalGridLSTMCell:
- class AttentionCellWrapper:
- class CompiledWrapper:
- class CoupledInputForgetGateLSTMCell:
- class DeviceWrapper:
- class DropoutWrapper:
- class EmbeddingWrapper:
- class FusedRNNCell:
- class FusedRNNCellAdaptor:
- class GRUBlockCell:
- class GRUBlockCellV2:
- class GridLSTMCell:
- class HighwayWrapper:
- class IndRNNCell:
- class IndyGRUCell:
- class IndyLSTMCell:
- class InputProjectionWrapper:
- class IntersectionRNNCell:
- class LSTMBlockCell:
- class LSTMBlockFusedCell:
- class LSTMBlockWrapper:
- class LayerNormBasicLSTMCell:
- class NASCell:
- class OutputProjectionWrapper:
- class PhasedLSTMCell:
- class ResidualWrapper:
- class SRUCell:
- class TimeFreqLSTMCell:
- class TimeReversedFusedRNN:
- class UGRNNCell:
函数
- static_rnn(…) # 将被弃用,和tf.nn.static_rnn是一样的
- static_bidirectional_rnn(…) # 将被弃用
- best_effort_input_batch_size(…)
- stack_bidirectional_dynamic_rnn(…)
- stack_bidirectional_rnn(…)
- static_state_saving_rnn(…)
- transpose_batch_time(…)
tf.contrib.rnn vs tf.nn.rnn_cell
事实上,这两个模块中都定义了许多RNN cell,contrib定义的是测试性的代码,而nn.rnn_cell是contrib中经过测试后的代码。
contrib中的代码会经常修改,而nn中的代码比较稳定。
contrib中的cell类型比较多,而nn中的比较少。
contrib和nn中有重复的cell,基本上nn中有的contrib中都有。
参考文献
1.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/RNNCell
2.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/BasicRNNCell
3.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMCell
4.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/MultiRNNCell
5.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/LSTMStateTuple
6.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell/DropoutWrapper
7.https://www.tensorflow.org/api_docs/python/tf/nn/static_rnn
8.https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn
9.https://www.tensorflow.org/api_docs/python/tf/contrib/rnn
10.https://www.tensorflow.org/api_docs/python/tf/nn/rnn_cell
11.https://www.cnblogs.com/wuzhitj/p/6297992.html
12.https://stackoverflow.com/questions/48001759/what-is-right-batch-normalization-function-in-tensorflow