TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模


pingan
平安喜乐 2022-09-19 17:25:00 53795
分类专栏: 资讯

TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模

目录

关于PTB数据集

代码实现


关于PTB数据集

PTB (Penn Treebank Dataset)文本数据集是语言模型学习中目前最被广泛使用数据集。
ptb.test.txt    测试集数据文件
ptb.train.txt   训练集数据文件
ptb.valid.txt   验证集数据文件
这三个数据文件中的数据已经经过了预处理,包含了10000 个不同的词语和语句结束标记符(在文本中就是换行符)以及标记稀有词语的特殊符号。
为了让使用PTB数据集更加方便,TensorFlow提供了两个函数来帮助实现数据的预处理。首先,TensorFlow提供了ptb_raw_data函数来读取PTB的原始数据,并将原始数据中的单词转化为单词ID。
训练数据中总共包含了929589 个单词,而这些单词被组成了一个非常长的序列。这个序列通过特殊的标识符给出了每句话结束的位置。在这个数据集中,句子结束的标识符ID为2。
数据集的下载地址:TF的PTB数据集     (别的数据集不匹配的话会出现错误)
 

代码实现

   本代码使用2层 LSTM 网络,且每层有 200 个隐藏单元。在训练中截断的输入序列长度为 32,且使用 Dropout 和梯度截断等方法控制模型的过拟合与梯度爆炸等问题。当简单地训练 3 个 Epoch 后,测试复杂度(Perplexity)降低到了 210,如果多轮训练会更低。

  1. -*- coding: utf-8 -*-
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import collections
  6. import os
  7. import sys
  8. import tensorflow as tf
  9. Py3 = sys.version_info[0] == 3
  10. def _read_words(filename):
  11. with tf.gfile.GFile(filename, "r") as f:
  12. if Py3:
  13. return f.read().replace("\n", "<eos>").split()
  14. else:
  15. return f.read().decode("utf-8").replace("\n", "<eos>").split()
  16. def _build_vocab(filename):
  17. data = _read_words(filename)
  18. counter = collections.Counter(data)
  19. count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
  20. words, _ = list(zip(*count_pairs))
  21. word_to_id = dict(zip(words, range(len(words))))
  22. return word_to_id
  23. def _file_to_word_ids(filename, word_to_id):
  24. data = _read_words(filename)
  25. return [word_to_id[word] for word in data if word in word_to_id]
  26. def ptb_raw_data(data_path=None):
  27. """Load PTB raw data from data directory "data_path".
  28. Reads PTB text files, converts strings to integer ids,
  29. and performs mini-batching of the inputs.
  30. The PTB dataset comes from Tomas Mikolov's webpage:
  31. http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
  32. Args:
  33. data_path: string path to the directory where simple-examples.tgz has
  34. been extracted.
  35. Returns:
  36. tuple (train_data, valid_data, test_data, vocabulary)
  37. where each of the data objects can be passed to PTBIterator.
  38. """
  39. train_path = os.path.join(data_path, "ptb.train.txt")
  40. valid_path = os.path.join(data_path, "ptb.valid.txt")
  41. test_path = os.path.join(data_path, "ptb.test.txt")
  42. word_to_id = _build_vocab(train_path)
  43. train_data = _file_to_word_ids(train_path, word_to_id)
  44. valid_data = _file_to_word_ids(valid_path, word_to_id)
  45. test_data = _file_to_word_ids(test_path, word_to_id)
  46. vocabulary = len(word_to_id)
  47. return train_data, valid_data, test_data, vocabulary
  48. def ptb_producer(raw_data, batch_size, num_steps, name=None):
  49. """Iterate on the raw PTB data.
  50. This chunks up raw_data into batches of examples and returns Tensors that
  51. are drawn from these batches.
  52. Args:
  53. raw_data: one of the raw data outputs from ptb_raw_data.
  54. batch_size: int, the batch size.
  55. num_steps: int, the number of unrolls.
  56. name: the name of this operation (optional).
  57. Returns:
  58. A pair of Tensors, each shaped [batch_size, num_steps]. The second element
  59. of the tuple is the same data time-shifted to the right by one.
  60. Raises:
  61. tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
  62. """
  63. with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
  64. raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
  65. data_len = tf.size(raw_data)
  66. batch_len = data_len // batch_size
  67. data = tf.reshape(raw_data[0 : batch_size * batch_len],
  68. [batch_size, batch_len])
  69. epoch_size = (batch_len - 1) // num_steps
  70. assertion = tf.assert_positive(
  71. epoch_size,
  72. message="epoch_size == 0, decrease batch_size or num_steps")
  73. with tf.control_dependencies([assertion]):
  74. epoch_size = tf.identity(epoch_size, name="epoch_size")
  75. i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
  76. x = tf.strided_slice(data, [0, i * num_steps],
  77. [batch_size, (i + 1) * num_steps])
  78. x.set_shape([batch_size, num_steps])
  79. y = tf.strided_slice(data, [0, i * num_steps + 1],
  80. [batch_size, (i + 1) * num_steps + 1])
  81. y.set_shape([batch_size, num_steps])
  82. return x, y
  1. from reader import *
  2. import tensorflow as tf
  3. import numpy as np
  4. data_path = 'F:/File_Python/Python_daydayup/data/simple-examples/data' F:/File_Python/Python_daydayup/data/simple-examples/data
  5. 隐藏层单元数与LSTM层级数
  6. hidden_size = 200
  7. num_layers = 2
  8. 词典规模
  9. vocab_size = 10000
  10. learning_rate = 1.0
  11. train_batch_size = 16
  12. 训练数据截断长度
  13. train_num_step = 32
  14. 在测试时不需要使用截断,测试数据为一个超长序列
  15. eval_batch_size = 1
  16. eval_num_step = 1
  17. num_epoch = 3
  18. 结点不被Dropout的概率
  19. keep_prob = 0.5
  20. 用于控制梯度爆炸的参数
  21. max_grad_norm = 5
  22. 通过ptbmodel 的类描述模型
  23. class PTBModel(-title class_ inherited__">object):
  24. def __init__(self, is_training, batch_size, num_steps):
  25. 记录使用的Batch大小和截断长度
  26. self.batch_size = batch_size
  27. self.num_steps = num_steps
  28. 定义输入层,维度为批量大小×截断长度
  29. self.input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
  30. 定义预期输出
  31. self.targets = tf.placeholder(tf.int32, [batch_size, num_steps])
  32. 定义使用LSTM结构为循环体,带Dropout的深度RNN
  33. lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(hidden_size)
  34. if is_training:
  35. lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=keep_prob)
  36. cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_layers)
  37. 初始化状态为0
  38. self.initial_state = cell.zero_state(batch_size, tf.float32)
  39. 将单词ID转换为单词向量,embedding的维度为vocab_size*hidden_size
  40. embedding = tf.get_variable('embedding', [vocab_size, hidden_size])
  41. 将一个批量内的单词ID转化为词向量,转化后的输入维度为批量大小×截断长度×隐藏单元数
  42. inputs = tf.nn.embedding_lookup(embedding, self.input_data)
  43. 只在训练时使用Dropout
  44. if is_training: inputs = tf.nn.dropout(inputs, keep_prob)
  45. 定义输出列表,这里先将不同时刻LSTM的输出收集起来,再通过全连接层得到最终输出
  46. outputs = []
  47. state 储存不同批量中LSTM的状态,初始为0
  48. state = self.initial_state
  49. with tf.variable_scope('RNN'):
  50. for time_step in range(num_steps):
  51. if time_step > 0: tf.get_variable_scope().reuse_variables()
  52. 从输入数据获取当前时间步的输入与前一时间步的状态,并传入LSTM结构
  53. cell_output, state = cell(inputs[:, time_step, :], state)
  54. 将当前输出加入输出队列
  55. outputs.append(cell_output)
  56. 将输出队列展开成[batch,hidden*num_step]的形状,再reshape为[batch*num_step, hidden]
  57. output = tf.reshape(tf.concat(outputs, 1), [-1, hidden_size])
  58. 将LSTM的输出传入全连接层以生成最后的预测结果。最后结果在每时刻上都是长度为vocab_size的张量
  59. 且经过softmax层后表示下一个位置不同词的概率
  60. weight = tf.get_variable('weight', [hidden_size, vocab_size])
  61. bias = tf.get_variable('bias', [vocab_size])
  62. logits = tf.matmul(output, weight) + bias
  63. 定义交叉熵损失函数,一个序列的交叉熵之和
  64. loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
  65. [logits], 预测的结果
  66. [tf.reshape(self.targets, [-1])], 期望正确的结果,这里将[batch_size, num_steps]压缩为一维张量
  67. [tf.ones([batch_size * num_steps], dtype=tf.float32)]) 损失的权重,所有为1表明不同批量和时刻的重要程度一样
  68. 计算每个批量的平均损失
  69. self.cost = tf.reduce_sum(loss) / batch_size
  70. self.final_state = state
  71. 只在训练模型时定义反向传播操作
  72. if not is_training: return
  73. trainable_variable = tf.trainable_variables()
  74. 控制梯度爆炸问题
  75. grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, trainable_variable), max_grad_norm)
  76. 如果需要使用Adam作为优化器,可以改为tf.train.AdamOptimizer(learning_rate),学习率需要降低至0.001左右
  77. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  78. 定义训练步骤
  79. self.train_op = optimizer.apply_gradients(zip(grads, trainable_variable))
  80. def run_epoch(session, model, data, train_op, output_log, epoch_size):
  81. total_costs = 0.0
  82. iters = 0
  83. state = session.run(model.initial_state)
  84. 使用当前数据训练或测试模型
  85. for step in range(epoch_size):
  86. x, y = session.run(data)
  87. 在当前批量上运行train_op并计算损失值,交叉熵计算的是下一个单词为给定单词的概率
  88. cost, state, _ = session.run([model.cost, model.final_state, train_op],
  89. {model.input_data: x, model.targets: y, model.initial_state: state})
  90. 将不同时刻和批量的概率就可得到困惑度的对数形式,将这个和做指数运算就可得到困惑度
  91. total_costs += cost
  92. iters += model.num_steps
  93. 只在训练时输出日志
  94. if output_log and step % 100 == 0:
  95. print("After %d steps, perplexity is %.3f" % (step, np.exp(total_costs / iters)))
  96. return np.exp(total_costs / iters)
  97. def main():
  98. train_data, valid_data, test_data, _ = ptb_raw_data(data_path)
  99. 计算一个epoch需要训练的次数
  100. train_data_len = len(train_data)
  101. train_batch_len = train_data_len // train_batch_size
  102. train_epoch_size = (train_batch_len - 1) // train_num_step
  103. valid_data_len = len(valid_data)
  104. valid_batch_len = valid_data_len // eval_batch_size
  105. valid_epoch_size = (valid_batch_len - 1) // eval_num_step
  106. test_data_len = len(test_data)
  107. test_batch_len = test_data_len // eval_batch_size
  108. test_epoch_size = (test_batch_len - 1) // eval_num_step
  109. initializer = tf.random_uniform_initializer(-0.05, 0.05)
  110. with tf.variable_scope("language_model", reuse=None, initializer=initializer):
  111. train_model = PTBModel(True, train_batch_size, train_num_step)
  112. with tf.variable_scope("language_model", reuse=True, initializer=initializer):
  113. eval_model = PTBModel(False, eval_batch_size, eval_num_step)
  114. 训练模型。
  115. with tf.Session() as session:
  116. tf.global_variables_initializer().run()
  117. train_queue = ptb_producer(train_data, train_model.batch_size, train_model.num_steps)

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=3838
赞同 0
评论 0 条
平安喜乐L0
粉丝 0 发表 6 + 关注 私信
上周热门
如何使用 StarRocks 管理和优化数据湖中的数据?  2935
【软件正版化】软件正版化工作要点  2854
统信UOS试玩黑神话:悟空  2811
信刻光盘安全隔离与信息交换系统  2702
镜舟科技与中启乘数科技达成战略合作,共筑数据服务新生态  1235
grub引导程序无法找到指定设备和分区  1205
点击报名 | 京东2025校招进校行程预告  162
华为全联接大会2024丨软通动力分论坛精彩议程抢先看!  160
2024海洋能源产业融合发展论坛暨博览会同期活动-海洋能源与数字化智能化论坛成功举办  156
金山办公2024算法挑战赛 | 报名截止日期更新  153
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

加入交流群

请使用微信扫一扫!