
Reverse-mode automatic differentiation with a graph in C++

  • Omar Aflak · 6 years ago

    I'm trying to implement reverse-mode automatic differentiation in C++.

    Here is the code:

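    #include <iostream>
    #include <utility>
    #include <vector>

    class Var {
        private:
            double value;
            char character;
            // each entry: (local derivative d(this)/d(child), non-owning pointer to the child)
            std::vector<std::pair<double, const Var*> > children;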
    
        public:
            Var(const double& _value=0, const char& _character='_') : value(_value), character(_character) {};
            void set_character(const char& character){ this->character = character; }
    
            // computes the derivative of the current object with respect to 'var'
            double gradient(Var* var) const{
                if(this==var){
                    return 1.0;
                }
    
                double sum=0.0;
                for(auto& pair : children){
                    // std::cout << "(" << this->character << " -> " <<  pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
                    sum += pair.first*pair.second->gradient(var);
                }
                return sum;
            }
    
            friend Var operator+(const Var& l, const Var& r){
                Var result(l.value+r.value);
                result.children.push_back(std::make_pair(1.0, &l));
                result.children.push_back(std::make_pair(1.0, &r));
                return result;
            }
    
            friend Var operator*(const Var& l, const Var& r){
                Var result(l.value*r.value);
                result.children.push_back(std::make_pair(r.value, &l));
                result.children.push_back(std::make_pair(l.value, &r));
                return result;
            }
    
            friend std::ostream& operator<<(std::ostream& os, const Var& var){
                os << var.value;
                return os;
            }
    };
    

    I try to run it like this:

    int main(int argc, char const *argv[]) {
        Var x(5,'x'), y(6,'y'), z(7,'z');
    
        Var k = z + x*y;
        k.set_character('k');
    
        std::cout << "k = " << k << std::endl;
        std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
        std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
        std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;
    
        return 0;
    }
    

    The computational graph that should be built is:

           x(5)   y(6)              z(7)
             \     /                 /
     ∂w/∂x=y  \   /  ∂w/∂y=x        /
               \ /                 /
              w=x*y               /
                 \               /  ∂k/∂z=1
                  \             /
          ∂k/∂w=1  \           /
                    \_________/
                         |
                       k=w+z
    

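    To compute ∂k/∂x, for example, I multiply the gradients along the edges of each path and sum the results over all paths. This is done recursively by double gradient(Var* var) const. So I get ∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x.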
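    The worked values below use the numbers from the graph above (x = 5, y = 6, z = 7, w = x·y). They show what the recursive chain-rule sum should produce for each input, and so what a correct gradient() is expected to return:

```latex
\begin{aligned}
\frac{\partial k}{\partial x} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial x} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial x} = 1\cdot y + 1\cdot 0 = 6,\\
\frac{\partial k}{\partial y} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial y} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial y} = 1\cdot x + 1\cdot 0 = 5,\\
\frac{\partial k}{\partial z} &= \frac{\partial k}{\partial w}\frac{\partial w}{\partial z} + \frac{\partial k}{\partial z}\frac{\partial z}{\partial z} = 1\cdot 0 + 1\cdot 1 = 1.
\end{aligned}
```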

    The problem

    Something goes wrong when there is an intermediate computation, like x*y here. When I uncomment the std::cout line in gradient(), I get:

    k = 37
    (k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
    (k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
    (_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
    (_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
    ∂k/∂x = 0
    (k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
    (k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
    (_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
    (_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
    ∂k/∂y = 5
    (k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
    (k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
    (_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
    (_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
    ∂k/∂z = 1
    

    It prints which variable is connected to which, then their addresses, and the weight of the connection (which should be the gradient).

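    The problem is the weight=0 on the edge between x and the intermediate variable holding x*y (I mean w in the graph). I don't know why this weight is zero while the other one, on the edge to y, is not.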

    Another thing I noticed is that if you swap the two lines in operator*:

    result.children.push_back(std::make_pair(l.value, &r));
    result.children.push_back(std::make_pair(r.value, &l));
    

    then it's the connection to y that gets cancelled instead.

  • Tony Delroy · 6 years ago

    The line:

    Var k = z + x*y;
    

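    calls operator*, which returns a temporary Var. That temporary is then bound to the r argument of operator+, where one of the pairs in k's children stores a pointer to it. But the temporary is destroyed at the end of that statement. By the time you call gradient(), the pointer dangles and gradient() reads an object that no longer exists.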

    If you store the intermediate product in a named variable instead:

    Var xy = x * y;
    xy.set_character('*');
    Var k = z + xy;
    k.set_character('k');
    

    …then your program produces:

    k = 37
    ∂k/∂x = 6
    ∂k/∂y = 5
    ∂k/∂z = 1
    

    Alternatively, keep the child nodes by value.


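    As a general tip for finding these errors: when your program seems to do something inexplicable (and/or crashes), try running it under a memory error detector such as valgrind: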

    ==22137== Invalid read of size 8
    ==22137==    at 0x1090EA: Var::gradient(Var*) const (in /home/median/so/deriv)
    ==22137==    by 0x109109: Var::gradient(Var*) const (in /home/median/so/deriv)
    ==22137==    by 0x108E12: main (in /home/median/so/deriv)
    ==22137==  Address 0x5b82cd0 is 0 bytes inside a block of size 32 free'd
    ==22137==    at 0x4C3123B: operator delete(void*) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
    ==22137==    by 0x109FC1: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109CDD: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::deallocate(std::allocator<std::pair<double, Var const*> >&, std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109963: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x1097BC: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~_Vector_base() (in /home/median/so/deriv)
    ==22137==    by 0x1095EA: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~vector() (in /home/median/so/deriv)
    ==22137==    by 0x109161: Var::~Var() (in /home/median/so/deriv)
    ==22137==    by 0x108D95: main (in /home/median/so/deriv)
    ==22137==  Block was alloc'd at
    ==22137==    at 0x4C3017F: operator new(unsigned long) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
    ==22137==    by 0x10A153: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::allocate(unsigned long, void const*) (in /home/median/so/deriv)
    ==22137==    by 0x10A060: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::allocate(std::allocator<std::pair<double, Var const*> >&, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109F03: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_allocate(unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109A8D: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_realloc_insert<std::pair<double, Var const*> >(__gnu_cxx::__normal_iterator<std::pair<double, Var const*>*, std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > > >, std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x1098CF: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::emplace_back<std::pair<double, Var const*> >(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x10973F: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::push_back(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x109520: operator*(Var const&, Var const&) (in /home/median/so/deriv)
    ==22137==    by 0x108D6F: main (in /home/median/so/deriv)
    

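    Another way to catch this is to add logging to the destructor, so you know when an object address mentioned in your other logging is no longer valid.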
