代码之家  ›  专栏  ›  技术社区  ›  Aaditya Ura

tensorflow:typeerror:fetch参数none的类型无效<class'none type'>

  •  0
  • Aaditya Ura  · 技术社区  · 6 年前

    我试图运行这个简单的程序来计算梯度,但是我没有得到任何错误:

    import tensorflow as tf
    import numpy as np
    
    batch_size = 5
    dim = 3
    hidden_units = 8
    
    
    sess = tf.Session()
    
    with sess.as_default():
        x = tf.placeholder(dtype=tf.float32, shape=[None, dim], name="x")
        y = tf.placeholder(dtype=tf.int32, shape=[None], name="y")
        w = tf.Variable(initial_value=tf.random_normal(shape=[dim, hidden_units]), name="w")
        b = tf.Variable(initial_value=tf.zeros(shape=[hidden_units]), name="b")
        logits = tf.nn.tanh(tf.matmul(x, w) + b)
    
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y,name="xentropy")
        # define model end
    
    
        # begin training
        optimizer = tf.train.GradientDescentOptimizer(1e-5)
        grads_and_vars = optimizer.compute_gradients(cross_entropy, tf.trainable_variables())
    
        # generate data
        data = np.random.randn(batch_size, dim)
        labels = np.random.randint(0, 10, size=batch_size)
    
        sess.run(tf.initialize_all_variables())
        gradients_and_vars = sess.run(grads_and_vars, feed_dict={x:data, y:labels})
        for g, v in gradients_and_vars:
            if g is not None:
                print "****************this is variable*************"
                print "variable's shape:", v.shape
                print v
                print "****************this is gradient*************"
                print "gradient's shape:", g.shape
                print g
    
    sess.close()
    

    错误:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-14-8096b2e21e06> in <module>()
         29 
         30     sess.run(tf.initialize_all_variables())
    ---> 31     outnet = sess.run(grads_and_vars, feed_dict={x:data, y:labels})
         32 #     print(gradients_and_vars)
         33 #         if g is not None:
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
        893     try:
        894       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 895                          run_metadata_ptr)
        896       if run_metadata:
        897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
       1107     # Create a fetch handler to take care of the structure of fetches.
       1108     fetch_handler = _FetchHandler(
    -> 1109         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
       1110 
       1111     # Run request and get response.
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
        411     """
        412     with graph.as_default():
    --> 413       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
        414     self._fetches = []
        415     self._targets = []
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
        231     elif isinstance(fetch, (list, tuple)):
        232       # NOTE(touts): This is also the code path for namedtuples.
    --> 233       return _ListFetchMapper(fetch)
        234     elif isinstance(fetch, dict):
        235       return _DictFetchMapper(fetch)
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
        338     """
        339     self._fetch_type = type(fetches)
    --> 340     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        341     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        342 
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
        338     """
        339     self._fetch_type = type(fetches)
    --> 340     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        341     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        342 
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
        231     elif isinstance(fetch, (list, tuple)):
        232       # NOTE(touts): This is also the code path for namedtuples.
    --> 233       return _ListFetchMapper(fetch)
        234     elif isinstance(fetch, dict):
        235       return _DictFetchMapper(fetch)
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in __init__(self, fetches)
        338     """
        339     self._fetch_type = type(fetches)
    --> 340     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        341     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        342 
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in <listcomp>(.0)
        338     """
        339     self._fetch_type = type(fetches)
    --> 340     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        341     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        342 
    
    //anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
        228     if fetch is None:
        229       raise TypeError('Fetch argument %r has invalid type %r' %
    --> 230                       (fetch, type(fetch)))
        231     elif isinstance(fetch, (list, tuple)):
        232       # NOTE(touts): This is also the code path for namedtuples.
    
    TypeError: Fetch argument None has invalid type <class 'NoneType'>
    

    为什么会出错?版本问题?

    1 回复  |  直到 6 年前
        1
  •  0
  •   Aaditya Ura    6 年前

    Gradients None

     print([v.name for v in tf.all_variables()])
    

    sess.run(tf.initialize_all_variables())
    gradients_and_vars = sess.run([variable for grad,variable in grads_and_vars], feed_dict={x:data, y:labels})
    print(gradients_and_vars)