代码之家  ›  专栏  ›  技术社区  ›  Gilfoyle

删除存储为数组的图像中的像素

  •  0
  • Gilfoyle  · 技术社区  · 6 年前

    我有一个核阵列 I 哪些商店 N 大小图像 P (像素数)。每个图像都有大小 P = q*q .

    N = 1000 # number of images
    q = 10 # length and width of image
    P = q*q # pixels of image
    I = np.ones((N,P)) # array with N images of size P
    

    现在我想删除大小的补丁 ps 围绕选定索引 IDX (将所有值设置为零)。

    ps = 2 # patch size (ps x ps)
    IDX = np.random.randint(0,P,(N,1))
    

    我的方法是使用 reshape(q,q) 删除周围的像素 小精灵 . 这里我有一个问题,就是我不知道如何计算给定图像中的位置 小精灵 . 此外,我必须检查索引是否不在图像之外。

    如何解决这个问题,有没有办法把这个过程矢量化?

    编辑:

    在@brenla的帮助下,我做了以下的工作来移除补丁。我的方法的问题是,它需要三个for循环,并且我必须重新塑造每个图像两次。有没有办法提高性能?这一部分大大降低了我的代码速度。

    import numpy as np
    import matplotlib.pyplot as plt
    
    def myplot(I):
        imgs = 10
        for i in range(imgs**2):
            plt.subplot(imgs,imgs,(i+1))
            plt.imshow(I[i].reshape(q,q), interpolation="none")
            plt.axis("off")
        plt.show()
    
    N = 10000
    q = 28
    P = q*q
    I = np.random.rand(N,P)
    
    ps = 3
    IDX = np.random.randint(0,P,(N,1))
    
    for i in range(N):
        img = I[i].reshape(q,q)
        y0, x0 = np.unravel_index(IDX[i,0],(q,q))
        for x in range(ps):
            for y in range(ps):
                if (x0+x < q) and (y0+y < q):
                    img[x0+x,y0+y] = 2.0
        I[i] = img.reshape(1,q*q)
    
    myplot(I)
    
    1 回复  |  直到 6 年前
        1
  •  1
  •   Brenlla    6 年前

    是的,这是可以做到的,但是需要大量使用 np.broadcasting .

    生成数据和 I :

    import time
    
    N = 10000
    q = 28
    P = q*q
    ps = 3 
    I = np.random.rand(N,P)
    IDX = np.random.randint(0,P,(N,1))
    I_copy = I.copy()
    

    现在运行循环解决方案。我切换 x0 y0 :

    t0=time.clock()
    for i in range(N):
        img = I[i].reshape(q,q)
        x0, y0 = np.unravel_index(IDX[i,0],(q,q))
        for x in range(ps):
            for y in range(ps):
                if (x0+x < q) and (y0+y < q):
                    img[x0+x,y0+y] = 2.0
        I[i] = img.reshape(1,q*q)
    print('With loop: {:.2f} ms'.format(time.clock()*1e3-t0*1e3))
    

    在我的机器上大约276毫秒。现在广播:

    t0 = time.clock()
    x_shift, y_shift = np.meshgrid(range(ps), range(ps))
    x, y = np.unravel_index(IDX, (q,q))
    #roi for region of interest
    roix = x[:,:,None]+x_shift; 
    roiy = y[:,:,None]+y_shift;
    roix[roix>q-1] = q-1; roiy[roiy>q-1] = q-1;
    I_copy.reshape(N,q,q)[np.arange(N)[:, None, None], roix, roiy] = 2.0
    
    print('No loop: {:.2f} ms'.format(time.clock()*1e3-t0*1e3))
    
    print(np.array_equal(I, I_copy))
    

    大约快80倍