LSTM in Deep Learning: Interpreting the tf.contrib.rnn.BasicLSTMCell(rnn_unit) Function


虚幻有微笑, 2022-09-19 13:05:47

Contents

Interpreting the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function

What the function does

Function source code


Interpreting the tf.contrib.rnn.BasicLSTMCell(rnn_unit) function

What the function does

  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline.  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.

  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

In plain terms: this is the basic LSTM recurrent network cell, implemented after http://arxiv.org/abs/1409.2329. A forget bias (default: 1.0) is added to the biases of the forget gate to reduce how much the cell forgets at the beginning of training. It supports neither cell clipping, nor a projection layer, nor peephole connections; it is the basic baseline, and for advanced models the full tf.nn.rnn_cell.LSTMCell should be used instead.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`.  If False, they are concatenated along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables in an existing scope.  If not `True`, and the existing scope already has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.
      dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead.
    """

Parameters:
num_units: int; the number of units in the LSTM cell.
forget_bias: float; the bias added to the forget gates (see above). Must be set to 0.0 manually when restoring from CudnnLSTM-trained checkpoints.
state_is_tuple: if True, accepted and returned states are 2-tuples of the c_state and m_state; if False, they are concatenated along the column axis. The latter behavior will soon be deprecated.
activation: the activation function of the inner states; defaults to tanh.
reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not True, and the existing scope already has the given variables, an error is raised.
name: string; the name of the layer. Layers with the same name will share weights, but to avoid mistakes reuse=True is required in such cases.
dtype: the default dtype of the layer (None means use the dtype of the first input). Required when build is called before call.

When restoring from CudnnLSTM-trained checkpoints, CudnnCompatibleLSTMCell must be used instead.
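
Before diving into the implementation, here is a minimal usage sketch. It assumes TensorFlow 1.x (where tf.contrib.rnn.BasicLSTMCell is available); the batch/time/feature sizes are illustrative placeholders, not values from the original post:

import tensorflow as tf  # assumes TensorFlow 1.x

rnn_unit = 128                                  # number of hidden units
batch_size, time_steps, input_dim = 32, 10, 64  # illustrative shapes

# A batch of input sequences: [batch, time, features].
x = tf.placeholder(tf.float32, [batch_size, time_steps, input_dim])

cell = tf.contrib.rnn.BasicLSTMCell(rnn_unit)
print(cell.state_size)   # LSTMStateTuple(c=128, h=128), since state_is_tuple=True
print(cell.output_size)  # 128

# tf.nn.dynamic_rnn unrolls the cell over the time axis.
init_state = cell.zero_state(batch_size, tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(cell, x, initial_state=init_state)
# outputs:     [batch_size, time_steps, rnn_unit], the new_h at every step
# final_state: LSTMStateTuple of (c, h), each [batch_size, rnn_unit]

Because state_is_tuple defaults to True, both cell.state_size and final_state are LSTMStateTuple objects that keep the cell state c and the hidden state h separate.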

Function source code

@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.  Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use
      `CudnnCompatibleLSTMCell` instead.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
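
To make the gate arithmetic in call concrete, the following NumPy sketch reproduces one step of the same computation. The function basic_lstm_step and all shapes here are illustrative, not part of TensorFlow:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def basic_lstm_step(inputs, c, h, kernel, bias, forget_bias=1.0):
    """One step of the BasicLSTMCell math, mirroring call() above."""
    # One fused matmul yields all four gate pre-activations: [batch, 4 * num_units].
    gate_inputs = np.concatenate([inputs, h], axis=1) @ kernel + bias
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = np.split(gate_inputs, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)

batch, input_depth, num_units = 2, 3, 4
kernel = np.random.randn(input_depth + num_units, 4 * num_units)  # shape as in build()
bias = np.zeros(4 * num_units)                                    # zeros_initializer
x = np.random.randn(batch, input_depth)
c = h = np.zeros((batch, num_units))

new_h, (new_c, _) = basic_lstm_step(x, c, h, kernel, bias)
print(new_h.shape, new_c.shape)  # (2, 4) (2, 4)

This mirrors new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j) and new_h = tanh(new_c) * sigmoid(o), and it shows why build creates a kernel of shape [input_depth + h_depth, 4 * num_units]: a single fused matmul produces the pre-activations of all four gates at once.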
