代码之家  ›  专栏  ›  技术社区  ›  politinsa

如何使np.where更有效地处理三角矩阵?

  •  1
  • politinsa  · 技术社区  · 6 年前

    我得到了这个代码,距离是一个下三角矩阵,定义如下:

    distance = np.tril(scipy.spatial.distance.cdist(points, points))  
    def make_them_touch(distance):
        """
        Return the every distance where two points touched each other. See example below.
        """
        thresholds = np.unique(distance)[1:] # to avoid 0 at the beginning, not taking a lot of time at all
        result = dict()
        for t in thresholds:
                x, y = np.where(distance == t)
                result[t] = [i for i in zip(x,y)]
        return result
    

    我的问题是大矩阵的np.where很慢(例如2000*100)。
    如何通过改进np.where或更改算法来加快此代码的速度?

    编辑: 作为 MaxU 指出,这里最好的优化不是生成平方矩阵和使用迭代器。

    例子:

    points = np.array([                                                                        
    ...: [0,0,0,0],                                                            
    ...: [1,1,1,1],         
    ...: [3,3,3,3],              
    ...: [6,6,6,6]                             
    ...: ])  
    
    In [106]: distance = np.tril(scipy.spatial.distance.cdist(points, points))
    
    In [107]: distance
    Out[107]: 
    array([[ 0.,  0.,  0.,  0.],
       [ 2.,  0.,  0.,  0.],
       [ 6.,  4.,  0.,  0.],
       [12., 10.,  6.,  0.]])
    
    In [108]: make_them_touch(distance)
    Out[108]: 
    {2.0: [(1, 0)],
     4.0: [(2, 1)],
     6.0: [(2, 0), (3, 2)],
     10.0: [(3, 1)],
     12.0: [(3, 0)]}
    
    1 回复  |  直到 6 年前
        1
  •  1
  •   MaxU - stand with Ukraine    6 年前

    更新1: 以下是 上面的 三角形距离矩阵(因为距离矩阵总是对称的,所以这并不重要):

    from itertools import combinations
    
    res = {tup[0]:tup[1] for tup in zip(pdist(points), list(combinations(range(len(points)), 2)))}
    

    结果:

    In [111]: res
    Out[111]:
    {1.4142135623730951: (0, 1),
     4.69041575982343: (0, 2),
     4.898979485566356: (1, 2)}
    

    更新2: 此版本将支持远距离复制:

    In [164]: import pandas as pd
    

    首先我们建造一只熊猫。系列:

    In [165]: s = pd.Series(list(combinations(range(len(points)), 2)), index=pdist(points))
    
    In [166]: s
    Out[166]:
    2.0     (0, 1)
    6.0     (0, 2)
    12.0    (0, 3)
    4.0     (1, 2)
    10.0    (1, 3)
    6.0     (2, 3)
    dtype: object
    

    现在我们可以按索引分组并生成坐标列表:

    In [167]: s.groupby(s.index).apply(list)
    Out[167]:
    2.0             [(0, 1)]
    4.0             [(1, 2)]
    6.0     [(0, 2), (2, 3)]
    10.0            [(1, 3)]
    12.0            [(0, 3)]
    dtype: object
    

    ps这里的主要思想是,如果要在之后将其展平并消除重复项,就不应该构建平方距离矩阵。