PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析


钥匙俭朴
钥匙俭朴 2022-09-20 10:01:56 52728
分类专栏: 资讯

PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

目录

输出结果

核心代码


输出结果

核心代码

  1. PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
  2. from sklearn.datasets import make_regression
  3. import seaborn as sns
  4. import pandas as pd
  5. import matplotlib.pyplot as plt
  6. sns.set()
  7. x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)
  8. df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})
  9. sns.lmplot(x='X', y='Y', data=df, fit_reg=True)
  10. plt.show()
  11. x_torch = torch.FloatTensor(x_train)
  12. y_torch = torch.FloatTensor(y_train)
  13. y_torch = y_torch.view(y_torch.size()[0], 1)
  14. class LinearRegression(torch.nn.Module): 定义LR的类。torch.nn库构建模型
  15. PyTorch 的 nn 库中有大量有用的模块,其中一个就是线性模块。如名字所示,它对输入执行线性变换,即线性回归。
  16. def __init__(self, input_size, output_size):
  17. super(LinearRegression, self).__init__()
  18. self.linear = torch.nn.Linear(input_size, output_size)
  19. def forward(self, x):
  20. return self.linear(x)
  21. model = LinearRegression(1, 1)
  22. criterion = torch.nn.MSELoss() 训练线性回归,我们需要从 nn 库中添加合适的损失函数。对于线性回归,我们将使用 MSELoss()——均方差损失函数
  23. optimizer = torch.optim.SGD(model.parameters(), lr=0.1)还需要使用优化函数(SGD),并运行与之前示例类似的反向传播。本质上,我们重复上文定义的 train() 函数中的步骤。
  24. 不能直接使用该函数的原因是我们实现它的目的是分类而不是回归,以及我们使用交叉熵损失和最大元素的索引作为模型预测。而对于线性回归,我们使用线性层的输出作为预测。
  25. for epoch in range(50):
  26. data, target = Variable(x_torch), Variable(y_torch)
  27. output = model(data)
  28. optimizer.zero_grad()
  29. loss = criterion(output, target)
  30. loss.backward()
  31. optimizer.step()
  32. predicted = model(Variable(x_torch)).data.numpy()
  33. 打印出原始数据和适合 PyTorch 的线性回归
  34. plt.plot(x_train, y_train, 'o', label='Original data')
  35. plt.plot(x_train, predicted, label='Fitted line')
  36. plt.legend()
  37. plt.title(u'Py:PyTorch实现简单合成数据集上的线性回归进行数据分析')
  38. plt.show()

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=4008
赞同 0
评论 0 条
钥匙俭朴L0
粉丝 0 发表 6 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2941
【软件正版化】软件正版化工作要点  2860
统信UOS试玩黑神话:悟空  2819
信刻光盘安全隔离与信息交换系统  2712
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1246
grub引导程序无法找到指定设备和分区  1213
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  163
点击报名 | 京东2025校招进校行程预告  162
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  160
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  157
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!