代码之家  ›  专栏  ›  技术社区  ›  mon

反向传播(Andrew Ng的库塞尔ML)梯度下降澄清

  •  0
  • mon  · 技术社区  · 4 年前

    Coursera ML Week 4 assignment

    % Calculate the gradients of Weight2
    % Derivative at Loss function J=L(Z) : dJ/dZ = (oi-yi)/oi(1-oi)
    % Derivative at Sigmoid activation function dZ/dY = oi(1-oi)
    
    delta_theta2 = oi - yi;  % <--- (dJ/dZ) * (dZ/dY) 
    
    # Using +/plus NOT -/minus
    Theta2_grad = Theta2_grad +     <-------- Why plus(+)?
                  bsxfun(@times, hi, transpose(delta_theta2)); 
    

    代码摘录

    for i = 1:m  
        % i is training set index of X (including bias). X(i, :) is 401 data.
        xi = X(i, :);
        yi = Y(i, :);
        
        % hi is the i th output of the hidden layer. H(i, :) is 26 data.
        hi = H(i, :);
        
        % oi is the i th output layer. O(i, :) is 10 data.
        oi = O(i, :);
        
        %------------------------------------------------------------------------
        % Calculate the gradients of Theta2
        %------------------------------------------------------------------------
        delta_theta2 = oi - yi;
        Theta2_grad = Theta2_grad + bsxfun(@times, hi, transpose(delta_theta2));
     
        %------------------------------------------------------------------------
        % Calculate the gradients of Theta1
        %------------------------------------------------------------------------
        % Derivative of g(z): g'(z)=g(z)(1-g(z)) where g(z) is sigmoid(H_NET).
        dgz = (hi .* (1 - hi));
        delta_theta1 = dgz .* sum(bsxfun(@times, Theta2, transpose(delta_theta2)));
        % There is no input into H0, hence there is no theta for H0. Remove H0.
        delta_theta1 = delta_theta1(2:end);
        Theta1_grad = Theta1_grad + bsxfun(@times, xi, transpose(delta_theta1));
    end
    

    我以为是减去导数。

    enter image description here

    1 回复  |  直到 4 年前
        1
  •  1
  •   ntlarry    4 年前

    由于梯度是通过平均所有训练示例的梯度来计算的,因此我们首先“累积”梯度,同时在所有训练示例上循环。我们通过对所有训练示例的梯度求和来实现这一点。所以用加号高亮显示的线不是渐变更新步骤。(注意alpha也不在那里)它可能在别的地方。它很可能在1到m的环外。

    另外,我不确定您何时会了解到这一点(我确定它在课程的某个地方),但您也可以将代码矢量化:)