DL/CNN: Variable-Length Text Recognition with a CNN-RNN (2-Layer GRU) Model (Keras + TensorFlow)


中兴通讯招聘2 · 2022-09-19 14:51:33


Contents

Output

Implementation Code


Output

To be updated…

Implementation Code

To be updated…

For the image_ocr code, see: DL/CNN: Implementing OCR (Optical Character Recognition) with a CNN (Keras, CTC loss, {image_ocr})

Keras's CTC loss function lives in https://github.com/fchollet/keras/blob/master/keras/backend/tensorflow_backend.py:

```python
import tensorflow as tf
from tensorflow.python.ops import ctc_ops as ctc

def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    # Arguments
        y_true: tensor `(samples, max_string_length)`
            containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)`
            containing the prediction, or output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.

    # Returns
        Tensor with shape `(samples, 1)` containing the
        CTC loss of each element.
    """
    label_length = tf.to_int32(tf.squeeze(label_length))
    input_length = tf.to_int32(tf.squeeze(input_length))
    sparse_labels = tf.to_int32(ctc_label_dense_to_sparse(y_true, label_length))
    y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
    return tf.expand_dims(ctc.ctc_loss(inputs=y_pred,
                                       labels=sparse_labels,
                                       sequence_length=input_length), 1)
```
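As a complement to the loss above, here is a hypothetical, framework-free sketch (not from the article) of the decoding side of CTC: greedy decoding takes the argmax class at each time step, collapses consecutive repeats, and drops the blank label. The blank index of 0 is an assumption made for this illustration only.

```python
BLANK = 0  # assumed blank index for this sketch

def ctc_greedy_decode(softmax_steps, blank=BLANK):
    """softmax_steps: list of per-time-step probability lists."""
    # argmax class at each time step
    best = [max(range(len(p)), key=p.__getitem__) for p in softmax_steps]
    decoded = []
    prev = None
    for label in best:
        # keep a label only when it differs from the previous step
        # and is not the blank symbol
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded

# the best path [1, 1, 0, 1, 2, 2] collapses to [1, 1, 2]:
steps = [
    [0.1, 0.8, 0.1],
    [0.1, 0.7, 0.2],
    [0.9, 0.05, 0.05],
    [0.1, 0.8, 0.1],
    [0.2, 0.1, 0.7],
    [0.1, 0.2, 0.7],
]
print(ctc_greedy_decode(steps))  # → [1, 1, 2]
```

The blank between the two 1s is what lets CTC emit a repeated character, which matters for words like "hello".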
Variable-length text recognition

```python
import os
import itertools
import re
import datetime
import cairocffi as cairo
import editdistance
import numpy as np
from scipy import ndimage
import pylab
from keras import backend as K
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Input, Dense, Activation, Reshape, Lambda
from keras.layers.merge import add, concatenate
from keras.layers.recurrent import GRU
from keras.models import Model
from keras.optimizers import SGD
from keras.utils.data_utils import get_file
from keras.preprocessing import image
from keras.callbacks import EarlyStopping, Callback
from keras.backend.tensorflow_backend import set_session
import tensorflow as tf
import matplotlib.pyplot as plt

# let TensorFlow grow GPU memory on demand instead of claiming it all up front
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

OUTPUT_DIR = 'image_ocr'
np.random.seed(55)
```

Import the helper functions from the official Keras example:

```python
!wget https://raw.githubusercontent.com/fchollet/keras/master/examples/image_ocr.py
from image_ocr import *
```
Define the necessary parameters:

```python
run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
start_epoch = 0
stop_epoch = 200
img_w = 128
img_h = 64
words_per_epoch = 16000
val_split = 0.2
val_words = int(words_per_epoch * val_split)

# network parameters
conv_filters = 16
kernel_size = (3, 3)
pool_size = 2
time_dense_size = 32
rnn_size = 512
input_shape = (img_w, img_h, 1)
```
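A quick plain-Python check of the train/validation split implied by the parameters above (values copied from them, nothing new assumed):

```python
words_per_epoch = 16000
val_split = 0.2
val_words = int(words_per_epoch * val_split)   # words held out for validation
train_words = words_per_epoch - val_words      # words seen during training
print(val_words, train_words)  # → 3200 12800
```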
Use these helpers and parameters to build a generator that produces text images of varying length:

```python
fdir = os.path.dirname(get_file('wordlists.tgz',
                                origin='http://www.mythic-ai.com/datasets/wordlists.tgz',
                                untar=True))
img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
                             bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
                             minibatch_size=32,
                             img_w=img_w,
                             img_h=img_h,
                             downsample_factor=(pool_size ** 2),
                             val_split=words_per_epoch - val_words)
```
Build the CNN:

```python
act = 'relu'
input_data = Input(name='the_input', shape=input_shape, dtype='float32')
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv1')(input_data)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)
inner = Conv2D(conv_filters, kernel_size, padding='same',
               activation=act, kernel_initializer='he_normal',
               name='conv2')(inner)
inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

conv_to_rnn_dims = (img_w // (pool_size ** 2),
                    (img_h // (pool_size ** 2)) * conv_filters)
inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

# cuts down input size going into the RNN
inner = Dense(time_dense_size, activation=act, name='dense1')(inner)
```
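The reshape target can be verified by hand: the two 2×2 poolings shrink each spatial axis by pool_size ** 2 = 4, and the channel axis is folded into the per-step feature dimension. A small sketch using the parameter values defined earlier:

```python
img_w, img_h = 128, 64
pool_size, conv_filters = 2, 16

time_steps = img_w // (pool_size ** 2)                 # 128 // 4 = 32 RNN time steps
features = (img_h // (pool_size ** 2)) * conv_filters  # (64 // 4) * 16 = 256 features per step
conv_to_rnn_dims = (time_steps, features)
print(conv_to_rnn_dims)  # → (32, 256)
```

Those 32 time steps are also the CTC sequence length, so labels longer than 32 characters could not be emitted, which is why the generator is told `downsample_factor=(pool_size ** 2)`.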
GRU part: two layers of bidirectional GRUs. GRU seems to work as well as, if not better than, LSTM:

```python
gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal',
            name='gru1')(inner)
gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True,
             kernel_initializer='he_normal', name='gru1_b')(inner)
gru1_merged = add([gru_1, gru_1b])
gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal',
            name='gru2')(gru1_merged)
gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
             kernel_initializer='he_normal', name='gru2_b')(gru1_merged)
```
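Note that the first bidirectional pair is merged with `add` (the width stays `rnn_size`), while the second pair is later merged with `concatenate` (the width doubles before the output dense layer). A tiny list-based sketch of the difference, with `rnn_size` shrunk from 512 to 4 for illustration:

```python
rnn_size, t_steps = 4, 3  # toy sizes standing in for 512 and 32
fwd = [[1.0] * rnn_size for _ in range(t_steps)]  # forward GRU outputs
bwd = [[2.0] * rnn_size for _ in range(t_steps)]  # backward GRU outputs

# add: element-wise sum, per-step width unchanged
added = [[f + b for f, b in zip(fs, bs)] for fs, bs in zip(fwd, bwd)]
# concatenate: per-step width doubles
concat = [fs + bs for fs, bs in zip(fwd, bwd)]

print(len(added[0]), len(concat[0]))  # → 4 8
```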
Transform the RNN output into character activations:

```python
inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
              name='dense2')(concatenate([gru_2, gru_2b]))
y_pred = Activation('softmax', name='softmax')(inner)
Model(inputs=input_data, outputs=y_pred).summary()
```
```python
labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len],
               dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
```

Keras doesn't currently support loss functions with extra parameters, so the CTC loss is implemented in a Lambda layer:

```python
loss_out = Lambda(ctc_lambda_func, output_shape=(1,),
                  name='ctc')([y_pred, labels, input_length, label_length])

# clipnorm seems to speed up convergence
sgd = SGD(lr=0.02, decay=1e-6, momentum=0.9, nesterov=True, clipnorm=5)
model = Model(inputs=[input_data, labels, input_length, label_length],
              outputs=loss_out)

# the loss calculation occurs elsewhere, so use a dummy lambda function for the loss
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
```
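The dummy loss works because the model's output already *is* the per-sample CTC loss computed inside the Lambda layer, so the compile-time loss only needs to pass it through unchanged. A minimal illustration:

```python
dummy_loss = lambda y_true, y_pred: y_pred  # ignores y_true entirely
ctc_layer_output = [0.7, 1.2]  # stand-in values for per-sample CTC losses
print(dummy_loss(None, ctc_layer_output))  # → [0.7, 1.2]
```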
```python
if start_epoch > 0:
    weight_file = os.path.join(OUTPUT_DIR,
                               os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
    model.load_weights(weight_file)

# captures the output of softmax so we can decode it during visualization
test_func = K.function([input_data], [y_pred])

# callback: at the end of each epoch it saves the model
# and visualizes the current predictions
viz_cb = VizCallback(run_name, test_func, img_gen.next_val())
```
Run the training:

```python
model.fit_generator(generator=img_gen.next_train(),
                    # steps are counted in minibatches of 32 words,
                    # as in the official image_ocr example
                    steps_per_epoch=(words_per_epoch - val_words) // 32,
                    epochs=stop_epoch,
                    validation_data=img_gen.next_val(),
                    validation_steps=val_words // 32,
                    callbacks=[EarlyStopping(patience=10), viz_cb, img_gen],
                    initial_epoch=start_epoch)
```
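A quick check of the resulting epoch sizes, assuming the generator's `minibatch_size` of 32 (the step counts in minibatches follow the official Keras image_ocr example):

```python
minibatch_size = 32
words_per_epoch, val_words = 16000, 3200
steps_per_epoch = (words_per_epoch - val_words) // minibatch_size  # 12800 / 32
validation_steps = val_words // minibatch_size                     # 3200 / 32
print(steps_per_epoch, validation_steps)  # → 400 100
```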


Original link: https://www.xckfsq.com/news/show.html?id=3102