TF之LSTM:利用基于顺序的LSTM回归算法对DIY数据集sin曲线(蓝虚)预测cos(红实)(matplotlib动态演示)


2008年的第一场雪
2008年的第一场雪 2022-09-20 11:13:04 49753
分类专栏: 资讯

TF之LSTM:利用基于顺序的LSTM回归算法对DIY数据集sin曲线(蓝虚)预测cos(红实)(matplotlib动态演示)

目录

输出结果

代码设计


输出结果

更新……


代码设计

  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. BATCH_START = 0
  5. TIME_STEPS = 20
  6. BATCH_SIZE = 50
  7. INPUT_SIZE = 1
  8. OUTPUT_SIZE = 1
  9. CELL_SIZE = 10
  10. LR = 0.006
  11. BATCH_START_TEST = 0
  12. def get_batch():
  13. global BATCH_START, TIME_STEPS
  14. xs shape (50batch, 20steps)
  15. xs = np.arange(BATCH_START, BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10*np.pi)
  16. seq = np.sin(xs)
  17. res = np.cos(xs)
  18. BATCH_START += TIME_STEPS
  19. return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]
  20. class LSTMRNN(-title class_ inherited__">object):
  21. def __init__(self, n_steps, input_size, output_size, cell_size, batch_size):
  22. self.n_steps = n_steps
  23. self.input_size = input_size
  24. self.output_size = output_size
  25. self.cell_size = cell_size
  26. self.batch_size = batch_size
  27. with tf.name_scope('inputs'):
  28. self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs')
  29. self.ys = tf.placeholder(tf.float32, [None, n_steps, output_size], name='ys')
  30. with tf.variable_scope('in_hidden'):
  31. self.add_input_layer()
  32. with tf.variable_scope('LSTM_cell'):
  33. self.add_cell()
  34. with tf.variable_scope('out_hidden'):
  35. self.add_output_layer()
  36. with tf.name_scope('cost'):
  37. self.compute_cost()
  38. with tf.name_scope('train'):
  39. self.train_op = tf.train.AdamOptimizer(LR).minimize(self.cost)
  40. def add_input_layer(self,):
  41. l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D')
  42. Ws_in = self._weight_variable([self.input_size, self.cell_size])
  43. bs_in = self._bias_variable([self.cell_size,])
  44. with tf.name_scope('Wx_plus_b'):
  45. l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in
  46. self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D')
  47. def add_cell(self):
  48. lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
  49. with tf.name_scope('initial_state'):
  50. self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
  51. self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
  52. lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)
  53. def add_output_layer(self):
  54. l_out_x = tf.reshape(self.cell_outputs, [-1, self.cell_size], name='2_2D')
  55. Ws_out = self._weight_variable([self.cell_size, self.output_size])
  56. bs_out = self._bias_variable([self.output_size, ])
  57. with tf.name_scope('Wx_plus_b'):
  58. self.pred = tf.matmul(l_out_x, Ws_out) + bs_out
  59. def compute_cost(self):
  60. losses = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
  61. [tf.reshape(self.pred, [-1], name='reshape_pred')],
  62. [tf.reshape(self.ys, [-1], name='reshape_target')],
  63. [tf.ones([self.batch_size * self.n_steps], dtype=tf.float32)],
  64. average_across_timesteps=True,
  65. softmax_loss_function=self.ms_error,
  66. name='losses'
  67. )
  68. with tf.name_scope('average_cost'):
  69. self.cost = tf.div(
  70. tf.reduce_sum(losses, name='losses_sum'),
  71. self.batch_size,
  72. name='average_cost')
  73. tf.summary.scalar('cost', self.cost)
  74. def ms_error(self, y_target, y_pre):
  75. return tf.square(tf.sub(y_target, y_pre))
  76. def _weight_variable(self, shape, name='weights'):
  77. initializer = tf.random_normal_initializer(mean=0., stddev=1.,)
  78. return tf.get_variable(shape=shape, initializer=initializer, name=name)
  79. def _bias_variable(self, shape, name='biases'):
  80. initializer = tf.constant_initializer(0.1)
  81. return tf.get_variable(name=name, shape=shape, initializer=initializer)
  82. if __name__ == '__main__':
  83. model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE)
  84. sess = tf.Session()
  85. merged=tf.summary.merge_all()
  86. writer=tf.summary.FileWriter("niu0127/logs0127",sess.graph)
  87. sess.run(tf.initialize_all_variables())
  88. plt.ion()
  89. plt.show()
  90. for i in range(200):
  91. seq, res, xs = get_batch()
  92. if i == 0:
  93. feed_dict = {
  94. model.xs: seq,
  95. model.ys: res,
  96. }
  97. else:
  98. feed_dict = {
  99. model.xs: seq,
  100. model.ys: res,
  101. model.cell_init_state: state
  102. }
  103. _, cost, state, pred = sess.run(
  104. [model.train_op, model.cost, model.cell_final_state, model.pred],
  105. feed_dict=feed_dict)
  106. plt.plot(xs[0,:],res[0].flatten(),'r',xs[0,:],pred.flatten()[:TIME_STEPS],'g--')
  107. plt.title('Matplotlib,RNN,Efficient learning,Approach,Cosx --Jason Niu')
  108. plt.ylim((-1.2,1.2))
  109. plt.draw()
  110. plt.pause(0.1)

相关文章
TF之RNN:matplotlib动态演示之基于顺序的RNN回归案例实现高效学习逐步逼近余弦曲线

 

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=4235
赞同 0
评论 0 条
2008年的第一场雪L0
粉丝 0 发表 8 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2950
【软件正版化】软件正版化工作要点  2872
统信UOS试玩黑神话:悟空  2833
信刻光盘安全隔离与信息交换系统  2728
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1261
grub引导程序无法找到指定设备和分区  1225
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  165
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  163
点击报名 | 京东2025校招进校行程预告  163
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  158
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
如何玩转信创开放社区—从小白进阶到专家 15
信创开放社区邀请他人注册的具体步骤如下 15
方德桌面操作系统 14
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
我有15积分有什么用? 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!