代码之家  ›  专栏  ›  技术社区  ›  Natjo

在梯度计算中,tensorflow如何处理不可微节点?

  •  2
  • Natjo  · 技术社区  · 6 年前

    我理解了自动微分的概念,但找不到任何解释,例如tensorflow如何计算不可微函数的误差梯度 tf.where 在我失去功能或 tf.cond

    1 回复  |  直到 6 年前
        1
  •  6
  •   javidcf    6 年前

    tf.where ,您有一个具有三个输入的函数,条件 C ,值为真 T 值为假 F Out T F C[0] True Out[0] 来自 T[0] ,它的梯度应该传播回来。另一方面, F[0] Out[1] False ,然后是 F[1] T[1] . 所以,简而言之 T 你应该传播给定的梯度 把它归零 F . 如果你看 the implementation of the gradient of tf.where ( Select operation) ,它确实做到了:

    @ops.RegisterGradient("Select")
    def _SelectGrad(op, grad):
      c = op.inputs[0]
      x = op.inputs[1]
      zeros = array_ops.zeros_like(x)
      return (None, array_ops.where(c, grad, zeros), array_ops.where(
          c, zeros, grad))
    

    tf.cond , the code is a bit more complicated Merge )在不同的上下文中使用,而且 tf.条件 Switch 开关 操作用于每个输入,因此激活的输入(如果条件是 是的 第二个)得到接收的梯度,另一个输入得到“关闭”的梯度(如 None

    推荐文章