MNIST的结果是0-9,常用softmax函数进行分类,输出结果。

softmax函数常用于分类,定义如下:

​ $$\mathrm{softmax}(x_i)=\frac{\exp(x_i)}{\sum_j \exp(x_j)}$$


# coding: utf-8
"""Train a single-layer softmax classifier on MNIST (TensorFlow 1.x).

Builds softmax(xW + b) over flattened 28x28 images, trains with plain
gradient descent on a quadratic (MSE) cost, and prints test accuracy
after every epoch.
"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data


# Load the dataset (downloads into MNIST_data/ on first run).
# one_hot=True encodes each label as a 10-dimensional one-hot vector.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Size of each mini-batch.
batch_size = 100

# Number of batches per epoch; integer division drops any remainder,
# so up to batch_size-1 trailing examples are skipped each epoch.
n_batch = mnist.train.num_examples // batch_size

print(mnist.train.num_examples)

# Placeholders: x is a flattened 28x28 image, y its one-hot label.
x = tf.placeholder(tf.float32, [None, 784])  # input
y = tf.placeholder(tf.float32, [None, 10])   # target output

# A minimal network: one fully-connected layer followed by softmax.
# Zero initialization is fine here because there are no hidden layers.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)

# Quadratic (MSE) cost function.
loss = tf.reduce_mean(tf.square(y - prediction))

# Gradient descent. NOTE(review): a learning rate of 3 is unusually
# large; it only behaves here because MSE-on-softmax gradients are
# tiny. The commented 0.2 below is the conventional choice.
train_step = tf.train.GradientDescentOptimizer(3).minimize(loss)
# train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# Op that initializes all variables.
init = tf.global_variables_initializer()

# Per-example booleans: predicted class == true class.
# argmax returns the index of the largest value along axis 1.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))

# Accuracy: cast booleans to float32, then take the mean.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(100):
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})

        # Evaluate on the full test set once per epoch.
        acc = sess.run(accuracy,
                       feed_dict={x: mnist.test.images, y: mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
        # Stop early once accuracy is good enough (never reached by
        # this single-layer model, which plateaus around 0.93).
        if acc > 0.99:
            break



#输出

# Iter 0,Testing Accuracy 0.9074

# Iter 1,Testing Accuracy 0.9168

# Iter 2,Testing Accuracy 0.9221

# Iter 3,Testing Accuracy 0.9223

# Iter 4,Testing Accuracy 0.9237

# Iter 5,Testing Accuracy 0.9248

# Iter 6,Testing Accuracy 0.9246

# Iter 7,Testing Accuracy 0.925

# Iter 8,Testing Accuracy 0.9257

# Iter 9,Testing Accuracy 0.9278

# Iter 10,Testing Accuracy 0.9268

# Iter 11,Testing Accuracy 0.928

# Iter 12,Testing Accuracy 0.9269

# Iter 13,Testing Accuracy 0.9288

# Iter 14,Testing Accuracy 0.9273

# Iter 15,Testing Accuracy 0.9282

# Iter 16,Testing Accuracy 0.9301

# Iter 17,Testing Accuracy 0.9287

# Iter 18,Testing Accuracy 0.9297

# Iter 19,Testing Accuracy 0.9296

# Iter 20,Testing Accuracy 0.9285

# Iter 21,Testing Accuracy 0.9286

# Iter 22,Testing Accuracy 0.9288

# Iter 23,Testing Accuracy 0.9285

# Iter 24,Testing Accuracy 0.9311

# Iter 25,Testing Accuracy 0.9298

# Iter 26,Testing Accuracy 0.9294

# Iter 27,Testing Accuracy 0.9299

# Iter 28,Testing Accuracy 0.9298

# Iter 29,Testing Accuracy 0.9298

# Iter 30,Testing Accuracy 0.9307

# Iter 31,Testing Accuracy 0.9305

# Iter 32,Testing Accuracy 0.9291

# Iter 33,Testing Accuracy 0.9295

# Iter 34,Testing Accuracy 0.9289

# Iter 35,Testing Accuracy 0.9301

# Iter 36,Testing Accuracy 0.93

# Iter 37,Testing Accuracy 0.9293

# Iter 38,Testing Accuracy 0.93

# Iter 39,Testing Accuracy 0.9298

# Iter 40,Testing Accuracy 0.9299

# Iter 41,Testing Accuracy 0.9304

# Iter 42,Testing Accuracy 0.9303

# Iter 43,Testing Accuracy 0.93

# Iter 44,Testing Accuracy 0.9308

# Iter 45,Testing Accuracy 0.9296

# Iter 46,Testing Accuracy 0.9291

# Iter 47,Testing Accuracy 0.9306

# Iter 48,Testing Accuracy 0.9311

# Iter 49,Testing Accuracy 0.9301

# Iter 50,Testing Accuracy 0.93

# Iter 51,Testing Accuracy 0.9301

# Iter 52,Testing Accuracy 0.9306

# Iter 53,Testing Accuracy 0.9303

# Iter 54,Testing Accuracy 0.9307

# Iter 55,Testing Accuracy 0.9295

# Iter 56,Testing Accuracy 0.9313

# Iter 57,Testing Accuracy 0.9295

# Iter 58,Testing Accuracy 0.9303

# Iter 59,Testing Accuracy 0.9299

# Iter 60,Testing Accuracy 0.9286

# Iter 61,Testing Accuracy 0.9301

# Iter 62,Testing Accuracy 0.9303

# Iter 63,Testing Accuracy 0.9289

# Iter 64,Testing Accuracy 0.9301

# Iter 65,Testing Accuracy 0.9296

# Iter 66,Testing Accuracy 0.9303

# Iter 67,Testing Accuracy 0.9313

# Iter 68,Testing Accuracy 0.9301

# Iter 69,Testing Accuracy 0.9312

# Iter 70,Testing Accuracy 0.9294

# Iter 71,Testing Accuracy 0.9283

# Iter 72,Testing Accuracy 0.9295

# Iter 73,Testing Accuracy 0.9305

# Iter 74,Testing Accuracy 0.929

# Iter 75,Testing Accuracy 0.9315

# Iter 76,Testing Accuracy 0.9306

# Iter 77,Testing Accuracy 0.9288

# Iter 78,Testing Accuracy 0.9312

# Iter 79,Testing Accuracy 0.9309

# Iter 80,Testing Accuracy 0.9298

# Iter 81,Testing Accuracy 0.9293

# Iter 82,Testing Accuracy 0.9295

# Iter 83,Testing Accuracy 0.9292

# Iter 84,Testing Accuracy 0.9291

# Iter 85,Testing Accuracy 0.9294

# Iter 86,Testing Accuracy 0.9298

# Iter 87,Testing Accuracy 0.9296

# Iter 88,Testing Accuracy 0.9301

# Iter 89,Testing Accuracy 0.9306

# Iter 90,Testing Accuracy 0.9297

# Iter 91,Testing Accuracy 0.9307

# Iter 92,Testing Accuracy 0.9289

# Iter 93,Testing Accuracy 0.931

# Iter 94,Testing Accuracy 0.9301

# Iter 95,Testing Accuracy 0.9302

# Iter 96,Testing Accuracy 0.9297

# Iter 97,Testing Accuracy 0.9299

# Iter 98,Testing Accuracy 0.9317

# Iter 99,Testing Accuracy 0.9297