TF之LiR:基于tensorflow实现机器学习之线性回归算法


伶俐就金针菇
伶俐就金针菇 2022-09-20 11:13:52 52910
分类专栏: 资讯

TF之LiR:基于tensorflow实现机器学习之线性回归算法

目录

输出结果

代码设计


输出结果


代码设计

  1. -*- coding: utf-8 -*-
  2. TF之LiR:基于tensorflow实现机器学习之线性回归算法
  3. import tensorflow as tf
  4. import numpy
  5. import matplotlib.pyplot as plt
  6. rng =numpy.random
  7. 参数设定
  8. learning_rate=0.01
  9. training_epochs=10000
  10. display_step=50 每隔50次迭代输出一次
  11. 训练数据
  12. train_X=numpy.asarray([……])
  13. train_Y=numpy.asarray([……])
  14. n_samples=train_X.shape[0]
  15. print("train_X:",train_X)
  16. print("train_Y:",train_Y)
  17. 设置placeholder
  18. X=tf.placeholder("float")
  19. Y=tf.placeholder("float")
  20. 设置模型的权重和偏置,因为是不断更新的所以采用Variable定义
  21. W=tf.Variable(rng.randn(),name="weight")
  22. b=tf.Variable(rng.randn(),name="bias")
  23. 设置线性回归方程LiR:w*x+b
  24. pred=tf.add(tf.multiply(X,W),b)
  25. cost=tf.reduce_sum(tf.pow(pred-Y,2))/(2*n_samples) 设置cost为均方差即reduce_sum函数
  26. optimizer=tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) 梯度下降,minimize函数默认下自动修正w和b
  27. init=tf.global_variables_initializer() 在session运算时初始化所有变量
  28. 开始训练
  29. with tf.Session() as sess:
  30. sess.run(init) 运行一下初始化的变量
  31. for epoch in range(training_epochs): 输入所有训练数据
  32. for(x,y) in zip(train_X,train_Y):
  33. sess.run(optimizer,feed_dict={X:x,Y:y})
  34. 打印出每次迭代的log日志,每隔50个打印一次
  35. if (epoch+1) % display_step ==0:
  36. c=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
  37. print("迭代次数Epoch:","%04d" % (epoch+1),"下降值cost=","{:.9f}".format(c),
  38. "W=",sess.run(W),"b=",sess.run(b))
  39. print("Optimizer Finished!")
  40. training_cost=sess.run(cost,feed_dict={X:train_X,Y:train_Y})
  41. print("Training cost=",training_cost,"W=",sess.run(W),"b=",sess.run(b))
  42. 绘图
  43. plt.rcParams['font.sans-serif']=['SimHei']
  44. plt.subplot(121)
  45. plt.plot(train_X, train_Y, 'ro', label='Original data')
  46. plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
  47. plt.legend()
  48. plt.title("TF之LiR:Original data")
  49. 测试样本
  50. test_X = numpy.asarray([6.83, 4.668, 8.9, 7.91, 5.7, 8.7, 3.1, 2.1])
  51. test_Y = numpy.asarray([1.84, 2.273, 3.2, 2.831,2.92, 3.24, 1.35, 1.03])
  52. print("Testing... (Mean square loss Comparison)")
  53. testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * test_X.shape[0]),
  54. feed_dict={X:test_X,Y:test_Y}) same function as cost above
  55. print("Testing cost=", testing_cost)
  56. print("Absolute mean square loss difference:", abs( training_cost - testing_cost))
  57. 绘图
  58. plt.subplot(122)
  59. plt.plot(test_X, test_Y, 'bo', label='Testing data')
  60. plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
  61. plt.legend()
  62. plt.title("TF之LiR:Testing data")
  63. plt.show()
  1. 迭代次数Epoch: 6300 下降值cost= 0.076938324 W= 0.25199208 b= 0.8008495
  2. ……
  3. 迭代次数Epoch: 10000 下降值cost= 0.076965131 W= 0.24998894 b= 0.80145526
  4. 迭代次数Epoch: 10000 下降值cost= 0.076942705 W= 0.25047526 b= 0.80151606
  5. 迭代次数Epoch: 10000 下降值cost= 0.076929517 W= 0.25114807 b= 0.801635
  6. 迭代次数Epoch: 10000 下降值cost= 0.076958008 W= 0.25011322 b= 0.8015234
  7. 迭代次数Epoch: 10000 下降值cost= 0.076990739 W= 0.24960834 b= 0.80136055
  8. Optimizer Finished!
  9. Training cost= 0.07699074 W= 0.24960834 b= 0.80136055
  10. Testing... (Mean square loss Comparison)
  11. Testing cost= 0.07910849
  12. Absolute mean square loss difference: 0.002117753

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=4238
赞同 0
评论 0 条
伶俐就金针菇L0
粉丝 0 发表 10 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2959
【软件正版化】软件正版化工作要点  2878
统信UOS试玩黑神话:悟空  2843
信刻光盘安全隔离与信息交换系统  2737
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1270
grub引导程序无法找到指定设备和分区  1235
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  165
点击报名 | 京东2025校招进校行程预告  164
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  163
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  159
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!