代码之家  ›  专栏  ›  技术社区  ›  Vlad

动态确定在TensorFlow中初始化变量的张量

  •  0
  • Vlad  · 技术社区  · 6 年前

    我想动态地确定一个张量(不是一个初始化值,而是一个张量初始值设定项!)。例如:

    var1 = tf.Variable(tf.random_normal(shape=(2, 2)))
    var2 = tf.Variable(var1.<some method that returns tf.random_normal(shape=(2, 2)>)
    

    Variable.initialized_value() Variable.initial_value

    1 回复  |  直到 6 年前
        1
  •  0
  •   Vlad    6 年前

    我发现了一个有点丑陋的方法来做到这一点,我发布了一个帖子,以防其他人会感兴趣:

    import tensorflow as tf
    import re
    
    var1 = tf.Variable(tf.random_normal(shape=(2, 2)))
    
    shape = var1.initializer.outputs[0].shape
    input_ = var1.initializer.node_def.input[1]
    tensor_name = (input_ if re.search(r'\d+$', input_) is None
                   else '_'.join(input_.split('_')[:-1]))
    
    var2 = tf.Variable(tf.__dict__[tensor_name](shape=shape))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.eval())
        print(var2.eval())
    # [[ 0.23845878 -1.1440094 ]
    #  [ 0.593299   -1.1108586 ]]
    # [[ 1.1235769   1.1481414 ]
    #  [ 1.8934027  -0.33171055]]
    

    无论如何,如果有人能提供一种更干净的方法,我会很高兴的。

    编辑

    这仅在初始化张量具有默认名称时有效。

    更新

    import tensorflow as tf
    var1 = tf.Variable(tf.random_normal(shape=(2, 2), mean=1, stddev=3))
    var2 = tf.contrib.copy_graph.copy_variable_to_graph(var1, var1.graph)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.eval())
        print(var2.eval())
    # [[1.7009485 0.9412894]
    #  [1.0769905 1.3085879]]
    # [[ 5.8595214  8.652523 ]
    #  [ 1.86671   -3.170361 ]]