代码之家  ›  专栏  ›  技术社区  ›  Gilfoyle

张量流:用最小或最大值替换张量中的所有值

  •  0
  • Gilfoyle  · 技术社区  · 6 年前

    我想用最小的项替换张量中的所有值:

    [1,-2,3,4,-4] -> [-4,-4,-4,-4,-4]
    

    现在我在做

    x = tf.random_normal([1,5], mean=0.0, stddev=1.0, dtype=tf.float32)
    y = tf.reduce_min(x) + 0.0*x
    

    有更好的方法吗?

    1 回复  |  直到 6 年前
        1
  •  1
  •   Peter Szoldan    6 年前

    好吧,如果使用 tf.fill() 而不是含蓄的广播加上附加。见下面的基准。此代码(已测试):

    import tensorflow as tf
    
    x = tf.random_normal([1,5], mean=0.0, stddev=1.0, dtype=tf.float32)
    y = tf.fill( tf.shape( x ), tf.reduce_min(x) )
    
    with tf.Session() as sess:
        res = sess.run( [ x, y ] )
        for v in res:
            print( v )
    

    将输出

    [-1.9890205-0.20791222 0.6901897 0.5605381 0.93578804]]
    【1.9890205-1.9890205-1.9890205-1.9890205-1.9890205-1.9890205】】

    根据需要(数字是随机的,但第二行的最小值是在同一形状中重复的第一行。)


    基准

    在我的本地计算机上 tf.填充() GPU上的版本花费了0.305秒,而CPU上的版本花费了1.479秒,张量形状的版本花费了0.191秒,CPU上的版本花费了1.923秒。 [ 10000, 10000 ] .

    这个 tf.填充() 上的原始版本花费0.082秒,而上的原始版本花费0.610秒。 https://colab.research.google.com 在CPU上,在GPU上分别为0.287和0.874秒。

    下面是我用于基准测试的代码:

    import tensorflow as tf
    import time
    
    with tf.device( "/gpu:0"):
        # x and m are in variables and calculated first so that the timing only measures
        # the fill vs. broadcast operation
        x = tf.Variable( tf.random_normal( [ 10000, 10000 ], mean=0.0, stddev=1.0, dtype=tf.float32 ) )
        m = tf.Variable( 0. )
        m_calc_op = tf.assign( m, tf.reduce_min( x ) )
        y1 = tf.fill( tf.shape( x ), m )
        y2 = m + 0.0 * x
    
    with tf.Session() as sess:
        sess.run( tf.global_variables_initializer() )
        sess.run( m_calc_op ) 
        #res = sess.run( [ y1, y2 ] ) # run it once
    
        start = time.clock()
        #for i in xrange( 10 ):
        res = sess.run( [ m, y1 ] )
        end = time.clock()
        print ( end - start, "m=", res[ 0 ] )
    
        start = time.clock()
        #for i in xrange( 10 ):
        res = sess.run( [ m, y2 ] )
        end = time.clock()
        print ( end - start, "m=", res[ 0 ] )
    

    注意,我注释了10倍的重复,因为它开始给出不合理的低值,可能有一些优化,如果输入没有改变,计算不会重新运行。我把张量放大了。