TF RNN: an example of using scope.reuse_variables() to tell TF to reuse the RNN's parameters



Contents

Output

Code design


Output

To be updated…
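The original post leaves this section as a placeholder, but the four print(Wi.name) calls in the code below make the console output predictable. Note that tf.name_scope('RNN') does not affect names created by tf.get_variable, so 'RNN' never appears in them; assuming the script runs as written, it should print:

    train_rnn/input_layer/weights:0
    test_rnn/input_layer/weights:0
    rnn/input_layer/weights:0
    rnn/input_layer/weights:0

The first two names differ, so train_rnn1 and test_rnn1 hold separate parameters; the last two are identical because scope.reuse_variables() makes the second RNN instance fetch the variables the first one created.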

Code design

from __future__ import print_function
import tensorflow as tf

# 22 scope (name_scope/variable_scope)


class TrainConfig:
    batch_size = 20
    time_steps = 20
    input_size = 10
    output_size = 2
    cell_size = 11
    learning_rate = 0.01


class TestConfig(TrainConfig):
    time_steps = 1


class RNN(object):
    def __init__(self, config):
        self._batch_size = config.batch_size
        self._time_steps = config.time_steps
        self._input_size = config.input_size
        self._output_size = config.output_size
        self._cell_size = config.cell_size
        self._lr = config.learning_rate
        self._built_RNN()

    def _built_RNN(self):
        with tf.variable_scope('inputs'):
            self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
            self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
        with tf.name_scope('RNN'):
            with tf.variable_scope('input_layer'):
                l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D')  # (batch*n_step, in_size)
                # Ws (in_size, cell_size)
                Wi = self._weight_variable([self._input_size, self._cell_size])
                print(Wi.name)
                # bs (cell_size, )
                bi = self._bias_variable([self._cell_size, ])
                # l_in_y = (batch * n_steps, cell_size)
                with tf.name_scope('Wx_plus_b'):
                    l_in_y = tf.matmul(l_in_x, Wi) + bi
                l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')

            with tf.variable_scope('cell'):
                cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
                with tf.name_scope('initial_state'):
                    self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)

                self.cell_outputs = []
                cell_state = self._cell_initial_state
                for t in range(self._time_steps):
                    if t > 0: tf.get_variable_scope().reuse_variables()
                    cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
                    self.cell_outputs.append(cell_output)
                self._cell_final_state = cell_state

            with tf.variable_scope('output_layer'):
                # cell_outputs_reshaped (BATCH*TIME_STEP, CELL_SIZE)
                cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
                Wo = self._weight_variable((self._cell_size, self._output_size))
                bo = self._bias_variable((self._output_size,))
                product = tf.matmul(cell_outputs_reshaped, Wo) + bo
                # _pred shape (batch*time_step, output_size)
                self._pred = tf.nn.relu(product)  # for displacement

        with tf.name_scope('cost'):
            _pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
            mse = self.ms_error(_pred, self._ys)
            mse_ave_across_batch = tf.reduce_mean(mse, 0)
            mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
            self._cost = mse_sum_across_time
            self._cost_ave_time = self._cost / self._time_steps

        with tf.variable_scope('train'):
            self._lr = tf.convert_to_tensor(self._lr)
            self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)

    @staticmethod
    def ms_error(y_target, y_pre):
        return tf.square(tf.subtract(y_target, y_pre))

    @staticmethod
    def _weight_variable(shape, name='weights'):
        initializer = tf.random_normal_initializer(mean=0., stddev=0.5)
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    @staticmethod
    def _bias_variable(shape, name='biases'):
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)
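
One detail worth flagging before the __main__ block: inside the 'cell' scope, the loop calls tf.get_variable_scope().reuse_variables() from the second time step on. BasicLSTMCell creates its kernel and bias through tf.get_variable the first time it is called, and calling it again in the same scope without turning reuse on would raise a "variable already exists" ValueError. A standalone sketch of that per-step pattern (TF 1.x; the shapes and the 'unrolled_cell' scope name are illustrative, not taken from the code above):

    import tensorflow as tf

    cell = tf.contrib.rnn.BasicLSTMCell(11)
    inputs = tf.placeholder(tf.float32, [20, 5, 11])  # (batch, time, features)
    state = cell.zero_state(20, dtype=tf.float32)
    outputs = []
    with tf.variable_scope('unrolled_cell'):
        for t in range(5):
            if t > 0:
                # the kernel/bias already exist after step 0; flip the scope
                # to reuse mode so every step shares the same weights
                tf.get_variable_scope().reuse_variables()
            out, state = cell(inputs[:, t, :], state)
            outputs.append(out)

The __main__ block below applies the same idea one level up, to two whole RNN instances: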
if __name__ == '__main__':
    train_config = TrainConfig()  # define train_config
    test_config = TestConfig()

    # the wrong method to reuse parameters in train rnn
    with tf.variable_scope('train_rnn'):
        train_rnn1 = RNN(train_config)
    with tf.variable_scope('test_rnn'):
        test_rnn1 = RNN(test_config)

    # the right method to reuse parameters in train rnn:
    # let the train RNN create the parameters, then share them through the
    # variable scope so the test RNN calls exactly the same parameters
    with tf.variable_scope('rnn') as scope:
        sess = tf.Session()
        train_rnn2 = RNN(train_config)
        scope.reuse_variables()  # tell TF we want to reuse the RNN's parameters
        test_rnn2 = RNN(test_config)
        # tf.initialize_all_variables() is no longer valid from
        # 2017-03-02 if using tensorflow >= 0.12
        if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
            init = tf.initialize_all_variables()
        else:
            init = tf.global_variables_initializer()
        sess.run(init)
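
The heart of the "right method" is how tf.get_variable interacts with the scope's reuse flag: the first RNN(train_config) creates every variable under rnn/..., and once scope.reuse_variables() has been called, any tf.get_variable with a matching name returns the existing variable instead of raising an error. A minimal, self-contained sketch of just that mechanism (the 'demo' scope and variable names are illustrative):

    import tensorflow as tf

    with tf.variable_scope('demo') as scope:
        # the first call creates demo/weights
        w1 = tf.get_variable('weights', shape=[3, 3],
                             initializer=tf.random_normal_initializer())
        scope.reuse_variables()  # from here on, get_variable fetches instead of creating
        # same name, same scope: this returns the SAME variable object
        w2 = tf.get_variable('weights', shape=[3, 3])

    print(w1.name)   # demo/weights:0
    print(w2.name)   # demo/weights:0
    print(w1 is w2)  # True: one shared set of parameters

Applied to the script above, a quick sanity check would be to print tf.trainable_variables() after both RNNs are built: every rnn/... weight should appear exactly once, while the "wrong method" scopes train_rnn/... and test_rnn/... each contribute their own separate copies.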

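A closing note for readers on TensorFlow 2.x: tf.contrib is gone and variable_scope-style reuse now lives in the compatibility module, so the pattern above only works in graph mode. A hedged sketch of the same sharing mechanism under TF2 (assuming tf.compat.v1 with v2 behavior disabled):

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()  # restore graph-mode semantics for variable_scope

    with tf.variable_scope('rnn') as scope:
        v1 = tf.get_variable('weights', shape=[2, 2])
        scope.reuse_variables()
        v2 = tf.get_variable('weights', shape=[2, 2])

    assert v1 is v2  # both names resolve to one shared variable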