I ran into the same problem of TensorFlow slowing down with every run, and found this while debugging. Here is a short description of my situation and how I solved it, for future reference. Hopefully it points people in the right direction and saves them some time.
In my case, the problem was mainly that I was not using feed_dict to supply the network state when calling sess.run(). Instead, I redeclared outputs, final_state and prediction inside the loop (see https://github.com/tensorflow/tensorflow/issues/1439#issuecomment-194405649):
# defining the network
lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1)
outputs, final_state = rnn.static_rnn(lstm_layer, input, initial_state=rnn_state, dtype='float32')
prediction = tf.nn.softmax(tf.matmul(outputs[-1], out_weights) + out_bias)

for input_data in data_seq:
    # redeclaring, stupid stupid...
    outputs, final_state = rnn.static_rnn(lstm_layer, input, initial_state=rnn_state, dtype='float32')
    prediction = tf.nn.softmax(tf.matmul(outputs[-1], out_weights) + out_bias)
    p, rnn_state = sess.run((prediction, final_state), feed_dict={x: input_data})
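A quick way to confirm that this is what is happening (my own diagnostic sketch for TF 1.x, not part of the original code) is to print the number of operations in the default graph on every pass; when nodes are redeclared inside the loop, the count keeps growing and each sess.run() gets slower:

import tensorflow as tf

# Minimal sketch: any op created inside the loop grows the graph.
val = tf.placeholder(tf.float32, [1])

with tf.Session() as sess:
    for it in range(3):
        doubled = val * 2.0  # new Const/Mul nodes are added on every pass
        sess.run(doubled, feed_dict={val: [1.0]})
        # the op count keeps rising instead of staying constant
        print(it, len(tf.get_default_graph().get_operations()))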
The solution, of course, was to declare the nodes only once at the beginning and supply the new data with feed_dict. The code went from semi-slow (> 15 ms at the start) and getting slower with every iteration to executing each iteration in about 1 ms. My new code looks like this:
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn

out_weights = tf.Variable(tf.random_normal([num_units, n_classes]), name="out_weights")
out_bias = tf.Variable(tf.random_normal([n_classes]), name="out_bias")

# placeholder for the network state
state_placeholder = tf.placeholder(tf.float32, [2, 1, num_units])
rnn_state = tf.nn.rnn_cell.LSTMStateTuple(state_placeholder[0], state_placeholder[1])

x = tf.placeholder('float', [None, 1, n_input])
input = tf.unstack(x, 1, 1)

# defining the network
lstm_layer = rnn.BasicLSTMCell(num_units, forget_bias=1)
outputs, final_state = rnn.static_rnn(lstm_layer, input, initial_state=rnn_state, dtype='float32')
prediction = tf.nn.softmax(tf.matmul(outputs[-1], out_weights) + out_bias)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# actual network state, which we feed in with feed_dict
_rnn_state = tf.nn.rnn_cell.LSTMStateTuple(np.zeros((1, num_units), dtype='float32'),
                                           np.zeros((1, num_units), dtype='float32'))

it = 0
for input_data in data_seq:
    encl_input = [[input_data]]
    p, _rnn_state = sess.run((prediction, final_state), feed_dict={x: encl_input, rnn_state: _rnn_state})
    print("{} - {}".format(it, p))
    it += 1
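As an extra safeguard (my own addition, not part of the original answer, assuming the graph and session built above), the graph can be finalized once construction is done; any accidental node creation inside the loop then raises an error immediately instead of silently slowing everything down:

# freeze the graph after it is fully defined
sess.graph.finalize()  # adding any new node after this raises a RuntimeError

for input_data in data_seq:
    encl_input = [[input_data]]
    p, _rnn_state = sess.run((prediction, final_state),
                             feed_dict={x: encl_input, rnn_state: _rnn_state})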
Moving the declarations out of the for loop also solved the problem sdrop 2002 had, namely performing the slicing outputs[-1] in sess.run() inside the for loop.
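The same principle applies to the slicing itself: a slice on a tensor is a graph node too, so it should be defined once during graph construction and only fetched inside the loop. A small self-contained illustration (the names and shapes are mine, TF 1.x):

import numpy as np
import tensorflow as tf

# Indexing a tensor such as seq[:, -1, :] creates a StridedSlice node,
# so it belongs in the graph definition, not inside the run loop.
seq = tf.placeholder(tf.float32, [None, 10, 4])  # [batch, time, features]
last_step = seq[:, -1, :]                        # built once, fetched many times

with tf.Session() as sess:
    for _ in range(3):
        batch = np.random.rand(2, 10, 4).astype('float32')
        print(sess.run(last_step, feed_dict={seq: batch}).shape)  # (2, 4)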