代码之家  ›  专栏  ›  技术社区  ›  Jame

如何将此代码矢量化?

  •  12
  • Jame  · 技术社区  · 7 年前

    我已经写了一个递归函数,但是,它需要很多时间。因此,我将其矢量化,但它不会产生与递归函数相同的结果。这是我的非矢量化代码:

    function visited = procedure_explore( u, adj_mat, visited )
    visited(u) = 1;
    neighbours = find(adj_mat(u,:));
    for ii = 1:length(neighbours)
        if (visited(neighbours(ii)) == 0)
            visited = procedure_explore( neighbours(ii), adj_mat, visited );
        end
    end
    end
    

    这是我的矢量化代码:

    function visited = procedure_explore_vec( u, adj_mat, visited )
    visited(u) = 1;
    neighbours = find(adj_mat(u,:));
    len_neighbours=length(neighbours);
    visited_neighbours_zero=visited(neighbours(1:len_neighbours)) == 0;
    if(~isempty(visited_neighbours_zero))
        visited = procedure_explore_vec( neighbours(visited_neighbours_zero), adj_mat, visited );
    end
    end
    

    function main
        adj_mat=[0 0 0 0;
                 1 0 1 1;
                 1 0 0 0;
                 1 0 0 1];
        u=2;
        visited=zeros(size(adj_mat,1));
        tic
        visited = procedure_explore( u, adj_mat, visited )
        toc
        visited=zeros(size(adj_mat,1));
        tic
        visited = procedure_explore_vec( u, adj_mat, visited )
        toc
    end
    

    这是我试图实现的算法: enter image description here

    如果无法进行矢量化,则mex解决方案也会很好。

    更新基准: 该基准基于MATLAB 2017a。结果表明,原始代码比其他方法更快

    Speed up between original and logical methods is 0.39672
    Speed up between original and nearest methods is 0.0042583
    

    完整代码

    function main_recersive
        adj_mat=[0 0 0 0;
                 1 0 1 1;
                 1 0 0 0;
                 1 0 0 1];
        u=2;
        visited=zeros(size(adj_mat,1));
        f_original=@()(procedure_explore( u, adj_mat, visited ));
        t_original=timeit(f_original);
    
        f_logical=@()(procedure_explore_logical( u, adj_mat ));
        t_logical=timeit(f_logical);
    
        f_nearest=@()(procedure_explore_nearest( u, adj_mat,visited ));
        t_nearest=timeit(f_nearest);
    
        disp(['Speed up between original and logical methods is ',num2str(t_original/t_logical)])
        disp(['Speed up between original and nearest methods is ',num2str(t_original/t_nearest)])    
    
    end
    
    function visited = procedure_explore( u, adj_mat, visited )
        visited(u) = 1;
        neighbours = find(adj_mat(u,:));
        for ii = 1:length(neighbours)
            if (visited(neighbours(ii)) == 0)
                visited = procedure_explore( neighbours(ii), adj_mat, visited );
            end
        end
    end
    
    function visited = procedure_explore_nearest( u, adj_mat, visited )
        % add u since your function also includes it.
        nodeIDs = [nearest(digraph(adj_mat),u,inf) ; u];
        % transform to output format of your function
        visited = zeros(size(adj_mat,1));
        visited(nodeIDs) = 1;
    
    end 
    
    function visited = procedure_explore_logical( u, adj_mat )
       visited = false(1, size(adj_mat, 1));
       visited(u) = true;
       new_visited = visited;
       while any(new_visited)
          visited = any([visited; new_visited], 1);
          new_visited = any(adj_mat(new_visited, :), 1);
          new_visited = and(new_visited, ~visited);
       end
    end
    
    6 回复  |  直到 7 年前
        1
  •  4
  •   beaker    7 年前

    这里有一个有趣的小函数,它在图上执行非递归宽度优先搜索。

    function visited = procedure_explore_logical( u, adj_mat )
       visited = false(1, size(adj_mat, 1));
       visited(u) = true;
       new_visited = visited;
    
       while any(new_visited)
          visited = any([visited; new_visited], 1);
          new_visited = any(adj_mat(new_visited, :), 1);
          new_visited = and(new_visited, ~visited);
       end
    end
    

    在倍频程中,它的运行速度大约是100x100邻接矩阵上递归版本的50倍。你必须在MATLAB上对其进行基准测试,看看你得到了什么。

        2
  •  2
  •   DasKrümelmonster    7 年前

    在具有n个节点的图中,最长路径不能长于n-1,因此可以将所有功率相加以进行可达性分析:

    adj_mat + adj_mat^2 + adj_mat^3
    ans =
       0   0   0   0
       4   0   1   3
       1   0   0   0
       3   0   0   3
    

    visited(v) = ans(v, :) > 0;
    

    根据您的定义,您可能需要更改结果中的列和行(即,取ans(:,v))。

    为了提高性能,您可以使用较低的功率来制作较高的功率。例如,可以有效地计算A+A^2+A^3+A^4+A^5:

    A2 = A^2;
    A3 = A2*A
    A4 = A2^2;
    A5 = A4*A;
    allWalks= A + A2 + A3 + A4 + A5;
    

    注:

    这样可以最大限度地减少矩阵乘法的次数,而且MATLAB执行矩阵平方运算的速度可能会快于常规乘法。

    根据我的经验,矩阵乘法在MATLAB中相对较快,这将一次生成图中所有节点的结果(可达性)向量。如果您只对大型图的一个子集感兴趣,那么这可能不是最佳解决方案。

    另请参见以下答案: https://stackoverflow.com/a/7276595/1974021

        3
  •  1
  •   Leander Moesinger    7 年前

    也就是说,如果没有循环或递归调用,通常不可能找到所有可到达的节点。例如,您可以检查 (有效或无效)路径。但是,这与您的功能有很大不同,并且根据节点的数量,可能会由于要检查的路径数量惊人而导致性能损失。您当前的功能还不错,可以很好地适应大型网络。

    有点离题,但由于Matlab 2016a,您可以使用 nearest() 查找所有可到达的节点(不包括起始节点)。与深度优先算法相比,它调用了广度优先算法:

    % add u since your function also includes it.
    nodeIDs = [nearest(digraph(adj_mat),u,inf) ; u]; 
    
    % transform to output format of your function
    visited = zeros(size(adj_mat,1));
    visited(nodeIDs) = 1;
    

    如果这是一个学生的项目,你可以争辩说,当你的函数工作时,出于性能原因,你使用了内置函数。

        4
  •  1
  •   rahnema1    7 年前

    visited(u) = 1; visited 在函数体中,不制作其副本,但在修改时,将创建其副本,并对其副本进行修改。为了防止这种情况,您可以使用 handle object

    visited_class.m ):

    classdef visited_class < handle
        properties
            visited
        end
        methods
            function obj = visited_class(adj_mat)
                obj.visited = zeros(1, size(adj_mat,1));
            end
        end
    end
    

    function procedure_explore_handle( u, adj_mat,visited_handle )
        visited_handle.visited(u) = 1;
        neighbours = find(adj_mat(u,:));
        for n = neighbours
            if (visited_handle.visited(n) == 0)
                procedure_explore_handle( n, adj_mat , visited_handle );
            end
        end
    end
    

    初始化变量:

    adj_mat=[0 0 0 0;
             1 0 1 1;
             1 0 0 0;
             1 0 0 1];
    visited_handle = visited_class(adj_mat);
    u = 2;
    

    procedure_explore_handle( u, adj_mat,visited_handle );
    

    结果保存到 visited_handle

    disp(visited_handle.visited)
    
        5
  •  0
  •   alle_meije    7 年前

    如果你想从图中的一个点转到另一个点,从资源角度来看,最有效的方法是Dijkstra算法。Floyd-Warshall算法计算所有点之间的所有距离,并且可以并行(从多个点开始)。

    adj_mat2=adj_mat^2;               % allowed to use 2 steps
    while (adj_mat2 ~= adj_mat)       % check if new points were reached
          adj_mat=adj_mat2;           % current set of reachable points
          adj_mat2=(adj_mat^2)>0;     % allowed all steps again: power method
    end
    
        6
  •  0
  •   Lior    7 年前

    这个答案只是给出了来自 DasKrümelmonster's answer ,我认为这比问题中的代码要快(至少在矩阵维数不太大的情况下)。它使用 polyvalm

    function visited = procedure_explore_vec(u, adj_mat)
        connectivity_matrix = polyvalm(ones(size(adj_mat,1),1),adj_mat)>0;
    
        visited = connectivity_matrix(u,:);
    end