代码之家  ›  专栏  ›  技术社区  ›  Roman

是否可以在不扩展TensorFlow中的计算图的情况下手动设置模型参数值?

  •  0
  • Roman  · 技术社区  · 5 年前

    我已经为一个模型创建了一个计算图。它有用于输入的节点( X ),目标( Y ),预测( P ),模型参数( P1, P2, ..., Pn ),以及成本( C ),这应尽量减少。除此之外,我还有代表模型参数成本梯度的节点( G1, G2, ...., Gn

    X 以及一些模型参数的初始值( )我计算了梯度值。

    现在,我想扫描不同的学习率,并为它们计算模型参数的新“候选”值,计算相应的成本,然后接受最佳“候选”参数作为模型参数的新值。

    我尝试过这样做:

    # get the current values of the model parameters
    ini_vals = [s.run(param) for param in params]
    
    # get the current values of the gradients
    grad_vals = [s.run(grad, feed_dict = feed_dict) for grad in grads]
    
    # start from the smallest allowed learning rate
    lr = min_lr
    
    best_score = None
    
    while True:
    
        # get new "candidate" values of model parameters for a given learning rate
        new_vals = [ini_val - lr * grad for ini_val, grad in zip(ini_vals, grad_vals)]
    
        # assign the new values to the corresponding symbolic nodes
        for i in range(len(new_vals)):
            s.run(tf.assign(params[i], new_vals[i]))
    
        # get the corresponding score
        score_val = s.run(score, feed_dict = feed_dict)
    
        # if the current score is better than the previous one, accept it
        if best_score == None or score_val < best_score:
            best_score = score_val
            best_lr = lr
            best_params = new_vals[:]
    
        # if the new score is worse than the previous one, stop the search
        else:
    
            # use the best found parameters as new model paramers
            for i in range(len(new_vals)):
                s.run(tf.assign(params[i], best_params[i]))
    
            break
    
        # increase the learning rate
        lr *= factor
    

    但是,每次后续调用时,此过程都会变得越来越慢。这个问题是由于使用 tf.assign 函数我扩展了计算图(正如我从 this answer ).

    0 回复  |  直到 5 年前
        1
  •  1
  •   Vlad    5 年前

    你可以用 tf.Variable.load() :

    import tensorflow as tf
    import numpy as np
    
    w = tf.Variable(tf.ones((2,)))
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(w)) # [1. 1.]
        w.load(np.zeros((2,)))
        print(sess.run(w)) # [0. 0.]
    
    推荐文章