DL之CNN:自定义SimpleConvNet【3层,im2col优化】利用mnist数据集实现手写数字识别多分类训练来评估模型
目录
- class Convolution:
- def __init__(self, W, b, stride=1, pad=0):
-
- ……
-
-
def forward(self, x):
    """Convolution forward pass, implemented as one big matrix product.

    x : input batch of shape (N, C, H, W).
    Returns the feature maps with shape (N, FN, out_h, out_w).
    """
    filter_n, _, filter_h, filter_w = self.W.shape
    batch, _, in_h, in_w = x.shape

    # Output spatial size for the given stride/padding.
    out_h = 1 + int((in_h + 2 * self.pad - filter_h) / self.stride)
    out_w = 1 + int((in_w + 2 * self.pad - filter_w) / self.stride)

    # im2col unrolls every receptive field into a row, so the whole
    # convolution collapses into a single np.dot.
    col = im2col(x, filter_h, filter_w, self.stride, self.pad)
    col_W = self.W.reshape(filter_n, -1).T
    out = np.dot(col, col_W) + self.b

    # Cache intermediates needed by backward().
    self.x = x
    self.col = col
    self.col_W = col_W

    # Rows -> (N, out_h, out_w, FN), then channels-first layout.
    return out.reshape(batch, out_h, out_w, -1).transpose(0, 3, 1, 2)
-
def backward(self, dout):
    """Convolution backward pass.

    dout : upstream gradient of shape (N, FN, out_h, out_w).
    Side effects: stores self.dW and self.db.
    Returns dx, the gradient with respect to the input x.

    Bug fix: the original returned `dx` without ever computing it
    (NameError at runtime). The column-space gradient `dcol` must be
    scattered back to image shape with col2im, the inverse of the
    im2col call made in forward().
    """
    FN, C, FH, FW = self.W.shape
    # (N, FN, out_h, out_w) -> (N*out_h*out_w, FN), matching the row
    # layout produced by im2col in forward().
    dout = dout.transpose(0, 2, 3, 1).reshape(-1, FN)

    self.db = np.sum(dout, axis=0)
    self.dW = np.dot(self.col.T, dout)
    self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

    dcol = np.dot(dout, self.col_W.T)
    # col2im is im2col's companion helper (presumably from the same
    # utility module -- confirm it is imported at file top).
    dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

    return dx
-
-
- class Pooling:
- def __init__(self, pool_h, pool_w, stride=1, pad=0):
-
- self.pool_h = pool_h
- self.pool_w = pool_w
- self.stride = stride
- self.pad = pad
-
- self.x = None
- self.arg_max = None
-
- ……
-
-
-
- class SimpleConvNet:
def __init__(self, input_dim=(1, 28, 28),
             conv_param=None,
             hidden_size=100, output_size=10, weight_init_std=0.01):
    """Simple 3-layer ConvNet: Conv - ReLU - Pool - Affine - ReLU - Affine.

    Parameters
    ----------
    input_dim : (channels, height, width) of each input image.
    conv_param : dict with keys 'filter_num', 'filter_size', 'pad',
        'stride'. Defaults to 30 5x5 filters, pad 0, stride 1.
        Fix: the original used a mutable dict as the default argument
        (shared across calls); replaced by a None sentinel with the
        identical default built inside -- behavior is unchanged.
    hidden_size : width of the fully connected hidden layer.
    output_size : number of classes (10 for MNIST).
    weight_init_std : std-dev of the Gaussian weight initialization.
    """
    if conv_param is None:
        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1}

    filter_num = conv_param['filter_num']
    filter_size = conv_param['filter_size']
    filter_pad = conv_param['pad']
    filter_stride = conv_param['stride']
    input_size = input_dim[1]  # assumes square input (H == W) -- TODO confirm
    conv_output_size = (input_size - filter_size + 2*filter_pad) / filter_stride + 1
    # 2x2 pooling with stride 2 halves each spatial dimension.
    pool_output_size = int(filter_num * (conv_output_size/2) * (conv_output_size/2))

    # Weight initialization: small Gaussian noise, zero biases.
    self.params = {}
    self.params['W1'] = weight_init_std * \
                        np.random.randn(filter_num, input_dim[0], filter_size, filter_size)
    self.params['b1'] = np.zeros(filter_num)
    self.params['W2'] = weight_init_std * \
                        np.random.randn(pool_output_size, hidden_size)
    self.params['b2'] = np.zeros(hidden_size)
    self.params['W3'] = weight_init_std * \
                        np.random.randn(hidden_size, output_size)
    self.params['b3'] = np.zeros(output_size)

    # Layers in forward order; OrderedDict preserves that order.
    self.layers = OrderedDict()
    self.layers['Conv1'] = Convolution(self.params['W1'], self.params['b1'],
                                       conv_param['stride'], conv_param['pad'])
    self.layers['Relu1'] = Relu()
    self.layers['Pool1'] = Pooling(pool_h=2, pool_w=2, stride=2)
    self.layers['Affine1'] = Affine(self.params['W2'], self.params['b2'])
    self.layers['Relu2'] = Relu()
    self.layers['Affine2'] = Affine(self.params['W3'], self.params['b3'])

    self.last_layer = SoftmaxWithLoss()
-
- ……
-
- def save_params(self, file_name="params.pkl"):
- params = {}
- for key, val in self.params.items():
- params[key] = val
- with open(file_name, 'wb') as f:
- pickle.dump(params, f)
-
- def load_params(self, file_name="params.pkl"):
- with open(file_name, 'rb') as f:
- params = pickle.load(f)
- for key, val in params.items():
- self.params[key] = val
-
- for i, key in enumerate(['Conv1', 'Affine1', 'Affine2']):
- self.layers[key].W = self.params['W' + str(i+1)]
- self.layers[key].b = self.params['b' + str(i+1)]
- train_loss:2.29956519109714
- === epoch:1, train_acc:0.216, test_acc:0.218 ===
- train_loss:2.2975110344641716
- train_loss:2.291654113382576
- train_loss:2.2858174689127875
- train_loss:2.272262093336837
- train_loss:2.267908303517325
- train_loss:2.2584119706864336
- train_loss:2.2258807222804693
- train_loss:2.2111025085252543
- train_loss:2.188119055308738
- train_loss:2.163215575430596
- train_loss:2.1191887076886724
- train_loss:2.0542599060672186
- train_loss:2.0244523646451915
- train_loss:1.9779786923239808
- train_loss:1.9248431928319325
- train_loss:1.7920653808470397
- train_loss:1.726860911000866
- train_loss:1.7075144252509131
- train_loss:1.6875413868425186
- train_loss:1.6347461097804266
- train_loss:1.5437112361395253
- train_loss:1.4987893515035628
- train_loss:1.3856720782969847
- train_loss:1.2002110952243676
- train_loss:1.2731100379603273
- train_loss:1.117132621224333
- train_loss:1.0622583460165833
- train_loss:1.0960592785565957
- train_loss:0.8692067763172185
- train_loss:0.8548780420217317
- train_loss:0.83872966253374
- train_loss:0.7819342397053507
- train_loss:0.7589812430284729
- train_loss:0.7955332004991336
- train_loss:0.8190930469691535
- train_loss:0.6297212128196131
- train_loss:0.8279837022068413
- train_loss:0.6996430264702379
- train_loss:0.5256550729087258
- train_loss:0.7288553394002595
- train_loss:0.7033049908220391
- train_loss:0.5679669207218877
- train_loss:0.6344174262581003
- train_loss:0.7151382401438272
- train_loss:0.5814593192354963
- train_loss:0.5736217677325146
- train_loss:0.5673622947809682
- train_loss:0.48303413903204395
- train_loss:0.452267909884157
- train_loss:0.4009118158839013
- === epoch:2, train_acc:0.818, test_acc:0.806 ===
- train_loss:0.5669686001623327
- train_loss:0.5358187806595359
- train_loss:0.3837535143737321
- train_loss:0.544335563142595
- train_loss:0.39288485196871803
- train_loss:0.49770310644457566
- train_loss:0.4610248131112265
- train_loss:0.36641463191798196
- train_loss:0.4874682221372042
- train_loss:0.38796698110644817
- train_loss:0.3620230776259665
- train_loss:0.4744726274001774
- train_loss:0.3086952062454927
- train_loss:0.40012397040718645
- train_loss:0.3634667070910744
- train_loss:0.3204093812396573
- train_loss:0.5063082359543781
- train_loss:0.5624992123039615
- train_loss:0.34281562891324663
- train_loss:0.3415065217065326
- train_loss:0.4946703009790488
- train_loss:0.48942997572068253
- train_loss:0.25416776815225534
- train_loss:0.3808555005314615
- train_loss:0.22793380858862108
- train_loss:0.4709915396804245
- train_loss:0.25826190862498605
- train_loss:0.44862426522901516
- train_loss:0.25519522472564815
- train_loss:0.5063495442657376
- train_loss:0.37233317168099206
- train_loss:0.4027673899570495
- train_loss:0.4234905061164214
- train_loss:0.44590221111177714
- train_loss:0.3846538639824134
- train_loss:0.3371733857576183
- train_loss:0.23612786737321756
- train_loss:0.4814543539448962
- train_loss:0.38362762929477556
- train_loss:0.5105811329813293
- train_loss:0.31729857191880056
- train_loss:0.43677582454472663
- train_loss:0.37362647454980324
- train_loss:0.2696715797445873
- train_loss:0.26682852302518134
- train_loss:0.18763432881504752
- train_loss:0.2886557425885745
- train_loss:0.23833327847639763
- train_loss:0.36315802981646
- train_loss:0.21083779781027828
- === epoch:3, train_acc:0.89, test_acc:0.867 ===
- train_loss:0.34070333399972674
- train_loss:0.3356587138064409
- train_loss:0.25919406618960505
- train_loss:0.31537349840856743
- train_loss:0.2276928810208216
- train_loss:0.32171416950979326
- train_loss:0.22754919179736025
- train_loss:0.37619164258262944
- train_loss:0.3221102374023198
- train_loss:0.36724681541104537
- train_loss:0.3310213819075522
- train_loss:0.33583429981768936
- train_loss:0.36054827740285833
- train_loss:0.3002031789326344
- train_loss:0.19480027104864756
- train_loss:0.3074748184113467
- train_loss:0.31035699050378
- train_loss:0.37289594799797554
- train_loss:0.38054981033442864
- train_loss:0.2150866558286973
- train_loss:0.4014488874986493
- train_loss:0.2643304660197891
- train_loss:0.31806887985854354
- train_loss:0.29365139713396693
- train_loss:0.33212651106203267
- train_loss:0.29544164636048587
- train_loss:0.4969991428069569
- train_loss:0.3348535409949116
- train_loss:0.18914984777413654
- train_loss:0.3868380951987871
- train_loss:0.26857192970788485
- train_loss:0.373151707743815
- train_loss:0.3522570704735893
- train_loss:0.204823140388568
- train_loss:0.3974239710544049
- train_loss:0.21753509102652058
- train_loss:0.26034229667679715
- train_loss:0.26991319118062235
- train_loss:0.30959776720795107
- train_loss:0.2718109180045845
- train_loss:0.2738413103423023
- train_loss:0.22209179719364106
- train_loss:0.5025051167945939
- train_loss:0.23308114849307443
- train_loss:0.24989561030033144
- train_loss:0.4666621160650158
- train_loss:0.3511547384608582
- train_loss:0.32856542443039893
- train_loss:0.29344954251556093
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
加入交流群
请使用微信扫一扫!