I am training a CNN with the following placeholders...
# mnist data image of shape [batch_size, 28, 28, 1].
x = tf.placeholder(tf.float32, [None, 28, 28, 1])
# 0-9 digits recognition => 10 classes.
y = tf.placeholder(tf.float32, [None, 10])
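For context, the MNIST reader I use below hands the images back as flat 784-pixel rows rather than 28x28x1 tensors, which is why the reshape shows up everywhere; a quick shape check (the sizes are the standard tutorial splits):
# features.*.images are flat [num_examples, 784] arrays, so they do not match
# the placeholder x directly.
print(features.train.images.shape)  # (55000, 784)
print(features.test.images.shape)   # (10000, 784)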
I train it inside a session, and since I have to reshape the data every time before feeding it to the placeholders, this is roughly what I am doing...
with tf.Session() as sess:
    for epoch in range(25):
        total_batch = int(features.train.num_examples / 500)
        avg_cost = 0
        for i in range(total_batch):
            batch_xs, batch_ys = features.train.next_batch(10)
            # Reshape the flat batch to [batch_size, 28, 28, 1] before feeding it.
            _, c = sess.run([train_op, loss],
                            feed_dict={x: batch_xs.reshape([-1, 28, 28, 1]), y: batch_ys})
...
...
Logits layer:
# Logits Layer.
# Create a dense layer with 10 neurons => 10 classes
# Output has a shape of [batch_size, 10]
logits = tf.layers.dense(inputs=dropout, units=10)
# Softmax layer for deriving probabilities.
pred = tf.nn.softmax(logits, name="softmax_tensor")
...so after all of the training, this is how I compute the predictions...
from tensorflow.examples.tutorials.mnist import input_data
...
...
features = input_data.read_data_sets("/tmp/data/", one_hot=True)
...
...
# list of booleans to determine the correct predictions
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
print(correct_prediction.eval({x:features.test.images.reshape([10000, 28, 28, 1]), y:features.test.labels}))
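For reference, here is a minimal sketch of how I reduce that boolean vector to a single accuracy number (the accuracy op below is my own addition, not part of the tutorial code):
# Cast the booleans to floats and average them to get the accuracy.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval({x: features.test.images.reshape([10000, 28, 28, 1]),
                     y: features.test.labels}))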
Unfortunately, I have to convert every batch of features data into the right format, because the placeholder x only accepts the shape [batch_size, 28, 28, 1].
Is there a better way to do this that will not make my computer crash? I have used the following code in other neural networks...
print(correct_prediction.eval({x:features.test.images, y:features.test.labels}))
...but I never had this kind of problem with those, which leads me to believe the cause is the reshape call.
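The only workaround I can think of is to evaluate the test set in smaller chunks so that the reshape never touches all 10,000 images at once; a rough sketch of that idea (the chunk size of 1,000 is arbitrary):
# Rough idea: evaluate the test set chunk by chunk instead of reshaping all
# 10,000 images in one go (chunk size chosen arbitrarily).
results = []
for start in range(0, features.test.num_examples, 1000):
    xs = features.test.images[start:start + 1000].reshape([-1, 28, 28, 1])
    ys = features.test.labels[start:start + 1000]
    results.append(correct_prediction.eval({x: xs, y: ys}))
But I would still like to know whether there is a cleaner way that avoids the per-batch reshape entirely.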