
numpy: reshaped array behaves unexpectedly

  •  0
  • duhaime  · Tech Community  · 6 years ago

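    I am trying to reshape a numpy array [ link ] and then reshape it back again, but I cannot get the result I want. My data starts out in shape (n_vertices, n_time, n_dimensions). I then transform it into shape (n_time, n_vertices * n_dimensions):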

    import numpy as np
    
    X = np.load('dance.npy')
    
    n_vertices, n_time, n_dims = X.shape    
    
    X = X.reshape(n_time, n_vertices * n_dims)
    

    By visualizing the data, I can see that the transformation above does not distort the internal values:

    import mpl_toolkits.mplot3d.axes3d as p3
    from mpl_toolkits.mplot3d.art3d import juggle_axes
    import matplotlib.pyplot as plt
    from IPython.display import HTML
    from matplotlib import animation
    import matplotlib
    
    matplotlib.rcParams['animation.embed_limit'] = 2**128
    
    def update_points(time, points, df):
      points._offsets3d = juggle_axes(df[:,time,0], df[:,time,1], df[:,time,2], 'z')
    
    def get_plot(df, lim=1, frames=200, duration=45, time_axis=1, reshape=False):
      if reshape: df = df.reshape(n_vertices, df.shape[time_axis], n_dims)
      fig = plt.figure()
      ax = p3.Axes3D(fig)
      ax.set_xlim(-lim, lim)
      ax.set_ylim(-lim, lim)
      ax.set_zlim(-lim, lim)
      points = ax.scatter(df[:,0,0], df[:,0,1], df[:,0,2], depthshade=False) # x,y,z vals
      return animation.FuncAnimation(fig, update_points, frames, interval=duration, fargs=(points, df), blit=False ).to_jshtml()
    
    HTML(get_plot(X, frames=200, time_axis=0, reshape=True))
    

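    This shows the data in motion (the vertices are the body parts of a dancer, and the visualization looks like a human body). That is all fine. However, when I try to visualize only the first 10 time slices of the data, the resulting plot does not show the first few frames of the visualization above; the form is in fact not humanoid at all: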

    HTML(get_plot(X[:20], frames=10, time_axis=0, reshape=True))
    

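    Can anyone help me understand why this slicing operation does not match the first few time frames of X? Any suggestions or observations would be very helpful.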

    1 answer  |  as of 6 years ago
        1
  •  0
  •   duhaime    6 years ago

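    It turned out that my reshape operation was not manipulating the array the way I thought it was. reshape only reinterprets the elements in their existing (row-major) order and never moves axes, so the vertex and time axes have to be swapped before flattening. The following functions reshape my original array X into a flattened form (with two axes) and then correctly restore it to the unflattened form (with three axes). I added comments and tests to make sure everything behaves as expected: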

    from math import floor
    
    def flatten(df, run_tests=True):
      '''
      df is a numpy array with the following three axes:
        df.shape[0] = the index of a vertex
        df.shape[1] = the index of a time stamp
        df.shape[2] = the index of a dimension (x, y, z)
      So df[1][0][2] is the value for the 1st vertex (0-based) at time 0 in dimension 2 (z).
      To flatten this dataframe will mean to push the data into shape:
        flattened.shape[0] = time index
        flattened.shape[1] = (vertex_index * 3) + dimension_index
      So flattened[1][3] will be dimension 0 (x) of vertex 1 (0-based) at time 1.
      '''
      if run_tests:
        assert df.ndim == 3  # expect (n_vertices, n_time, n_dims)
    
      # reshape df such that flattened.shape = time, [x0, y0, z0, x1, y1, z1, ... xn-1, yn-1, zn-1]
      # (use the df argument rather than the global X so the function works on any input)
      flattened = df.swapaxes(0, 1).reshape( (df.shape[1], df.shape[0] * df.shape[2]), order='C' )
    
      if run_tests: # switch to false to skip tests
        for idx, i in enumerate(df):
          for jdx, j in enumerate(df[idx]):
            for kdx, k in enumerate(df[idx][jdx]):
              assert flattened[jdx][ (idx*df.shape[2]) + kdx ] == df[idx][jdx][kdx]
    
      return flattened
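
    To see why the swapaxes call matters, here is a minimal sketch on a toy array (not the dance data). The array `toy` and its value encoding are assumptions made up for illustration: each value is vertex*100 + time*10 + dim, so it is easy to read off where every element ends up.

```python
import numpy as np

# Toy stand-in for X (hypothetical): 2 vertices, 3 time steps, 2 dims.
# Each value encodes its own index as vertex*100 + time*10 + dim.
n_vertices, n_time, n_dims = 2, 3, 2
toy = (np.arange(n_vertices)[:, None, None] * 100
       + np.arange(n_time)[None, :, None] * 10
       + np.arange(n_dims)[None, None, :])

# Plain reshape keeps the row-major element order, so rows mix time steps.
naive = toy.reshape(n_time, n_vertices * n_dims)

# Swapping the vertex and time axes first puts one time step in each row.
swapped = toy.swapaxes(0, 1).reshape(n_time, n_vertices * n_dims)

print(naive)    # row 0 holds times 0 and 1 of vertex 0
print(swapped)  # row 0 holds time 0 of vertices 0 and 1
```

    In the naive result the tens digit (the time index) changes within a row, whereas in the swapped result every row has a single time index.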
    

    And to unflatten the flattened data:

    def unflatten(df, run_tests=True):
      '''
      df is a numpy array with the following two axes:
        df.shape[0] = time index
        df.shape[1] = (vertex_index * 3) + dimension_index
    
      To unflatten this dataframe will mean to push the data into shape:
        unflattened.shape[0] = the index of a vertex
        unflattened.shape[1] = the index of a time stamp
        unflattened.shape[2] = the index of a dimension (x, y, z)
    
      So df[2][4] == unflattened[1][2][1]
      '''
      if run_tests:
        assert (len(df.shape) == 2) and (df.shape[1] == X.shape[0] * X.shape[2])
    
      unflattened = np.zeros(( X.shape[0], df.shape[0], X.shape[2] ))
    
      for idx, i in enumerate(df):
        for jdx, j in enumerate(df[idx]):
          kdx = floor(jdx / 3)
          ldx = jdx % 3
          unflattened[kdx][idx][ldx] = df[idx][jdx]
    
      if run_tests: # set to false to skip tests
        for idx, i in enumerate(unflattened):
          for jdx, j in enumerate(unflattened[idx]):
            for kdx, k in enumerate(unflattened[idx][jdx]):
              assert( unflattened[idx][jdx][kdx] == X[idx][jdx][kdx] )
    
      return unflattened
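
    The element-by-element loop above can probably be replaced by a reshape followed by swapaxes. This is only a sketch: the function name `unflatten_vectorized`, the `n_dims=3` default, and the random `demo` array are assumptions for illustration, not part of the original code.

```python
import numpy as np

def unflatten_vectorized(flat, n_dims=3):
    """Hypothetical loop-free counterpart of unflatten().

    flat has shape (n_time, n_vertices * n_dims); the result has
    shape (n_vertices, n_time, n_dims).
    """
    n_time = flat.shape[0]
    n_vertices = flat.shape[1] // n_dims
    # Split each row back into (vertex, dim), then move the vertex axis first.
    return flat.reshape(n_time, n_vertices, n_dims).swapaxes(0, 1)

# Round-trip check on random data standing in for X.
rng = np.random.default_rng(0)
demo = rng.random((5, 7, 3))                      # (n_vertices, n_time, n_dims)
flat_demo = demo.swapaxes(0, 1).reshape(7, 5 * 3)  # same layout flatten() produces
restored = unflatten_vectorized(flat_demo)
print(np.array_equal(restored, demo))
```

    Note that swapaxes returns a view, so call .copy() on the result if a contiguous, independent array is needed.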
    

    Then to visualize:

    import mpl_toolkits.mplot3d.axes3d as p3
    from mpl_toolkits.mplot3d.art3d import juggle_axes
    import matplotlib.pyplot as plt
    from IPython.display import HTML
    from matplotlib import animation
    import matplotlib
    
    # raise matplotlib's size limit (in MB) for animations embedded as HTML
    matplotlib.rcParams['animation.embed_limit'] = 2**128
    
    def update_points(time, points, df):
      points._offsets3d = juggle_axes(df[:,time,0], df[:,time,1], df[:,time,2], 'z')
    
    def get_plot(df, lim=1, frames=200, duration=45):
      if len(df.shape) == 2: df = unflatten(df)
      fig = plt.figure()
      ax = p3.Axes3D(fig)
      ax.set_xlim(-lim, lim)
      ax.set_ylim(-lim, lim)
      ax.set_zlim(-lim, lim)
      points = ax.scatter(df[:,0,0], df[:,0,1], df[:,0,2], depthshade=False) # x,y,z vals
      return animation.FuncAnimation(fig,
        update_points,
        frames,
        interval=duration,
        fargs=(points, df),
        blit=False  
      ).to_jshtml()
    
    # get_plot() is called further down, once flat and unflat have been defined
    

    This lets me slice the time axis without any problems:

    flat = flatten(X)
    unflat = unflatten(flat)
    
    HTML(get_plot(unflat, frames=200))
    HTML(get_plot(flat[:20], frames=20))
    HTML(get_plot(unflat[:,:20,:], frames=20))
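
    As a quick check without any plotting, the sketch below uses a random array as a hypothetical stand-in for dance.npy (the name `X_demo` and its shape are made up). It suggests why the full animation in the question looked fine while the sliced one did not: reshaping back and forth without slicing is an exact round trip, but the rows of the naive 2D array are not time steps, so slicing them picks the wrong elements.

```python
import numpy as np

rng = np.random.default_rng(1)
X_demo = rng.random((4, 30, 3))   # hypothetical (n_vertices, n_time, n_dims)
n_vertices, n_time, n_dims = X_demo.shape

# Approach from the answer: swap axes, flatten, slice rows, restore.
flat_demo = X_demo.swapaxes(0, 1).reshape(n_time, n_vertices * n_dims)
first20 = flat_demo[:20].reshape(20, n_vertices, n_dims).swapaxes(0, 1)
same = np.array_equal(first20, X_demo[:, :20, :])

# Approach from the question: plain reshape in both directions.
naive = X_demo.reshape(n_time, n_vertices * n_dims)
roundtrip_ok = np.array_equal(naive.reshape(n_vertices, n_time, n_dims), X_demo)
naive20 = naive[:20].reshape(n_vertices, 20, n_dims)
naive_same = np.array_equal(naive20, X_demo[:, :20, :])

print(same, roundtrip_ok, naive_same)
```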