Dataset之MNIST:自定义函数mnist.load_mnist根据网址下载mnist数据集(四个ubyte.gz格式数据集文件)


上海人
上海人 2022-09-19 15:21:40 53021
分类专栏: 资讯

Dataset之MNIST:自定义函数mnist.load_mnist根据网址下载mnist数据集(四个ubyte.gz格式数据集文件)

目录

下载结果

运行代码


下载结果

运行代码

mnist.py文件

  1. coding: utf-8
  2. try:
  3. import urllib.request
  4. except ImportError:
  5. raise ImportError('You should use Python 3.x')
  6. import os.path
  7. import gzip
  8. import pickle
  9. import os
  10. import numpy as np
  11. url_base = 'http://yann.lecun.com/exdb/mnist/'
  12. key_file = {
  13. 'train_img':'train-images-idx3-ubyte.gz',
  14. 'train_label':'train-labels-idx1-ubyte.gz',
  15. 'test_img':'t10k-images-idx3-ubyte.gz',
  16. 'test_label':'t10k-labels-idx1-ubyte.gz'
  17. }
  18. dataset_dir = os.path.dirname(os.path.abspath(__file__))
  19. save_file = dataset_dir + "/mnist.pkl"
  20. train_num = 60000
  21. test_num = 10000
  22. img_dim = (1, 28, 28)
  23. img_size = 784
  24. def _download(file_name):
  25. file_path = dataset_dir + "/" + file_name
  26. if os.path.exists(file_path):
  27. return
  28. print("Downloading " + file_name + " ... ")
  29. urllib.request.urlretrieve(url_base + file_name, file_path)
  30. print("Done")
  31. def download_mnist():
  32. for v in key_file.values():
  33. _download(v)
  34. def _load_label(file_name):
  35. file_path = dataset_dir + "/" + file_name
  36. print("Converting " + file_name + " to NumPy Array ...")
  37. with gzip.open(file_path, 'rb') as f:
  38. labels = np.frombuffer(f.read(), np.uint8, offset=8)
  39. print("Done")
  40. return labels
  41. def _load_img(file_name):
  42. file_path = dataset_dir + "/" + file_name
  43. print("Converting " + file_name + " to NumPy Array ...")
  44. with gzip.open(file_path, 'rb') as f:
  45. data = np.frombuffer(f.read(), np.uint8, offset=16)
  46. data = data.reshape(-1, img_size)
  47. print("Done")
  48. return data
  49. def _convert_numpy():
  50. dataset = {}
  51. dataset['train_img'] = _load_img(key_file['train_img'])
  52. dataset['train_label'] = _load_label(key_file['train_label'])
  53. dataset['test_img'] = _load_img(key_file['test_img'])
  54. dataset['test_label'] = _load_label(key_file['test_label'])
  55. return dataset
  56. def init_mnist():
  57. download_mnist()
  58. dataset = _convert_numpy()
  59. print("Creating pickle file ...")
  60. with open(save_file, 'wb') as f:
  61. pickle.dump(dataset, f, -1)
  62. print("Done!")
  63. def _change_one_hot_label(X):
  64. T = np.zeros((X.size, 10))
  65. for idx, row in enumerate(T):
  66. row[X[idx]] = 1
  67. return T
  68. def load_mnist(normalize=True, flatten=True, one_hot_label=False):
  69. """读入MNIST数据集
  70. Parameters
  71. ----------
  72. normalize : 将图像的像素值正规化为0.0~1.0
  73. one_hot_label :
  74. one_hot_label为True的情况下,标签作为one-hot数组返回
  75. one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组
  76. flatten : 是否将图像展开为一维数组
  77. Returns
  78. -------
  79. (训练图像, 训练标签), (测试图像, 测试标签)
  80. """
  81. if not os.path.exists(save_file):
  82. init_mnist()
  83. with open(save_file, 'rb') as f:
  84. dataset = pickle.load(f)
  85. if normalize:
  86. for key in ('train_img', 'test_img'):
  87. dataset[key] = dataset[key].astype(np.float32)
  88. dataset[key] /= 255.0
  89. if one_hot_label:
  90. dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
  91. dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
  92. if not flatten:
  93. for key in ('train_img', 'test_img'):
  94. dataset[key] = dataset[key].reshape(-1, 1, 28, 28)
  95. return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])
  96. if __name__ == '__main__':
  97. init_mnist()

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3245
赞同 0
评论 0 条
上海人L0
粉丝 0 发表 11 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2951
【软件正版化】软件正版化工作要点  2872
统信UOS试玩黑神话:悟空  2833
信刻光盘安全隔离与信息交换系统  2728
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1261
grub引导程序无法找到指定设备和分区  1226
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  165
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  163
点击报名 | 京东2025校招进校行程预告  163
华为纯血鸿蒙正式版9月底见!但Mate 70的内情还得接着挖...  159
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
如何玩转信创开放社区—从小白进阶到专家 15
信创开放社区邀请他人注册的具体步骤如下 15
方德桌面操作系统 14
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
我有15积分有什么用? 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!