代码之家  ›  专栏  ›  技术社区  ›  Hitesh

如何在keras中保存模型过滤器

  •  1
  • Hitesh  · 技术社区  · 6 年前

    我正在使用来自 here ,具体如下:

    from mpl_toolkits.axes_grid1 import make_axes_locatable
    def nice_imshow(ax, data, vmin=None, vmax=None, cmap=None):
        """Wrapper around pl.imshow"""
        if cmap is None:
            cmap = cm.jet
        if vmin is None:
            vmin = data.min()
        if vmax is None:
            vmax = data.max()
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        im = ax.imshow(data, vmin=vmin, vmax=vmax, interpolation='nearest', cmap=cmap)
        pl.colorbar(im, cax=cax)
    #    pl.savefig("featuremaps--{}".format(layer_num) + '.jpg')
    
    import numpy.ma as ma
    def make_mosaic(imgs, nrows, ncols, border=1):
        """
        Given a set of images with all the same shape, makes a
        mosaic with nrows and ncols
        """
        nimgs = imgs.shape[0]
        imshape = imgs.shape[1:]
    
        mosaic = ma.masked_all((nrows * imshape[0] + (nrows - 1) * border,
                                ncols * imshape[1] + (ncols - 1) * border),
                                dtype=np.float32)
    
        paddedh = imshape[0] + border
        paddedw = imshape[1] + border
        for i in range(nimgs):
            row = int(np.floor(i / ncols))
            col = i % ncols
    
            mosaic[row * paddedh:row * paddedh + imshape[0],
                   col * paddedw:col * paddedw + imshape[1]] = imgs[i]
        return mosaic
    
    
    # Visualize weights
    W=model.layers[8].get_weights()[0][:,:,0,:]
    W=np.swapaxes(W,0,2)
    W = np.squeeze(W)
    print("W shape : ", W.shape)
    
    pl.figure(figsize=(15, 15))
    pl.title('conv1 weights')
    nice_imshow(pl.gca(), make_mosaic(W, 8, 8), cmap=cm.binary)
    

    我想保存过滤器图像。我们通常使用 fig.savefig("featuremaps-kernel-{}".format(layer_num) + '.jpg') 为了保存数据。但在这种情况下不起作用,可能是因为nice_u函数。请帮助我必须写什么命令来保存图形使用命令不是手动的。因为如果有大的网络,就有很多手工工作。

    1 回复  |  直到 6 年前
        1
  •  1
  •   filippo    6 年前

    我也遇到过类似的问题,我想用 plt.savefig . 它总是导致空白图像。

    我从来没有真正弄清楚它为什么会发生,如果我记得正确的话,它只发生在使用多处理时,但我可能错了。

    我用一个非交互的后端解决了这个问题,如果你不打算用 plt.show() .

    在matplotlib导入的顶部添加

    import matplotlib as mpl
    mpl.use('Agg')
    

    另外,如果在某个时刻保存了许多这样的图像,matplotlib会抱怨打开的图形太多。你应该加一个 plt.close() 每次通话后 保存图 .

    很抱歉,这纯粹是个轶事,也许有人会有更好的见解。