代码之家  ›  专栏  ›  技术社区  ›  Jacob

将三维矩阵与二维矩阵相乘

  •  25
  • Jacob  · 技术社区  · 15 年前

    假设我有一个 AXBXC 矩阵 X 和A BXD 矩阵 Y .

    有没有一个非循环的方法,我可以用它来乘以 C AXB 矩阵 Y ?

    9 回复  |  直到 6 年前
        1
  •  15
  •   gnovice    14 年前

    您可以使用函数在一行中执行此操作 NUM2CELL 打破矩阵 X 在一个单元数组中 CELLFUN 跨单元操作:

    Z = cellfun(@(x) x*Y,num2cell(X,[1 2]),'UniformOutput',false);
    

    结果 Z 是一个 1-B-C 单元格数组,其中每个单元格包含 A—D 矩阵。如果你想要 Z轴 成为一个 A -B-B-C 矩阵,你可以使用 CAT 功能:

    Z = cat(3,Z{:});
    



    注: 我以前用的解决方案 MAT2CELL 而不是 NUM2单元 不是那么简单:

    [A,B,C] = size(X);
    Z = cellfun(@(x) x*Y,mat2cell(X,A,B,ones(1,C)),'UniformOutput',false);
    
        2
  •  16
  •   Zaid    15 年前

    作为个人偏好,我喜欢我的代码尽可能简洁易读。

    以下是我将要做的,尽管它不满足您的“无循环”要求:

    for m = 1:C
    
        Z(:,:,m) = X(:,:,m)*Y;
    
    end
    

    这将导致 A×D×C 矩阵 Z .

    当然,您可以通过使用 Z = zeros(A,D,C); .

        3
  •  8
  •   Amro    15 年前

    这是一个一行解决方案(如果要拆分为第三维度,则为两行):

    A = 2;
    B = 3;
    C = 4;
    D = 5;
    
    X = rand(A,B,C);
    Y = rand(B,D);
    
    %# calculate result in one big matrix
    Z = reshape(reshape(permute(X, [2 1 3]), [A B*C]), [B A*C])' * Y;
    
    %'# split into third dimension
    Z = permute(reshape(Z',[D A C]),[2 1 3]);
    

    因此现在: Z(:,:,i) 包含的结果 X(:,:,i) * Y


    说明:

    上面的内容可能看起来很混乱,但这个想法很简单。 首先,我从三维开始 X 并沿第一个dim执行垂直连接:

    XX = cat(1, X(:,:,1), X(:,:,2), ..., X(:,:,C))
    

    ……困难在于 C 是变量,因此不能使用 维特卡特 . 接下来我们用这个乘以 Y :

    ZZ = XX * Y;
    

    最后,我将它分解为第三个维度:

    Z(:,:,1) = ZZ(1:2, :);
    Z(:,:,2) = ZZ(3:4, :);
    Z(:,:,3) = ZZ(5:6, :);
    Z(:,:,4) = ZZ(7:8, :);
    

    所以你可以看到它只需要一个矩阵乘法,但是你必须 重塑 前后矩阵。

        4
  •  5
  •   cspence    9 年前

    我正在处理完全相同的问题,并着眼于最有效的方法。我看到的方法大致有三种,除了使用外部库(即, mtimesx ):

    1. 循环通过三维矩阵的切片
    2. 重复和排列巫术
    3. Cellfun乘法

    我最近比较了这三种方法,看哪种方法最快。我的直觉是(2)会是赢家。代码如下:

    % generate data
    A = 20;
    B = 30;
    C = 40;
    D = 50;
    
    X = rand(A,B,C);
    Y = rand(B,D);
    
    % ------ Approach 1: Loop (via @Zaid)
    tic
    Z1 = zeros(A,D,C);
    for m = 1:C
        Z1(:,:,m) = X(:,:,m)*Y;
    end
    toc
    
    % ------ Approach 2: Reshape+Permute (via @Amro)
    tic
    Z2 = reshape(reshape(permute(X, [2 1 3]), [A B*C]), [B A*C])' * Y;
    Z2 = permute(reshape(Z2',[D A C]),[2 1 3]);
    toc
    
    
    % ------ Approach 3: cellfun (via @gnovice)
    tic
    Z3 = cellfun(@(x) x*Y,num2cell(X,[1 2]),'UniformOutput',false);
    Z3 = cat(3,Z3{:});
    toc
    

    所有三种方法产生相同的输出(phew!)但是,令人惊讶的是,这个循环是最快的:

    Elapsed time is 0.000418 seconds.
    Elapsed time is 0.000887 seconds.
    Elapsed time is 0.001841 seconds.
    

    请注意,从一个试验到另一个试验的时间变化很大,有时(2)是最慢的。随着数据的增加,这些差异变得更加显著。但与 许多的 更大的数据,(3)比(2)。循环方法仍然是最好的。

    % pretty big data...
    A = 200;
    B = 300;
    C = 400;
    D = 500;
    Elapsed time is 0.373831 seconds.
    Elapsed time is 0.638041 seconds.
    Elapsed time is 0.724581 seconds.
    
    % even bigger....
    A = 200;
    B = 200;
    C = 400;
    D = 5000;
    Elapsed time is 4.314076 seconds.
    Elapsed time is 11.553289 seconds.
    Elapsed time is 5.233725 seconds.
    

    但是循环法 可以 如果环尺寸比其他尺寸大得多,则应慢于(2)。

    A = 2;
    B = 3;
    C = 400000;
    D = 5;
    Elapsed time is 0.780933 seconds.
    Elapsed time is 0.073189 seconds.
    Elapsed time is 2.590697 seconds.
    

    所以(2)以一个大因素获胜,在这个(可能是极端的)情况下。可能并不是所有情况下都是最佳的方法,但是循环仍然很好,在许多情况下也是最佳的。在可读性方面也是最好的。走开!

        5
  •  1
  •   Rook    15 年前

    不。有几种方法,但总是以循环的形式出现,直接的或间接的。

    为了取悦我的好奇心,你为什么要那样做?

        6
  •  1
  •   user649198    11 年前

    要回答这个问题, 有关可读性,请参见:

    • ndmult ,Ajunapi(Juan Pablo Carbajal),2013年,GNU GPL

    输入

    • 2阵
    • 昏暗的

    例子

     nT = 100;
     t = 2*pi*linspace (0,1,nT)’;
    
     # 2 experiments measuring 3 signals at nT timestamps
     signals = zeros(nT,3,2);
     signals(:,:,1) = [sin(2*t) cos(2*t) sin(4*t).^2];
     signals(:,:,2) = [sin(2*t+pi/4) cos(2*t+pi/4) sin(4*t+pi/6).^2];
    
     sT(:,:,1) = signals(:,:,1)’;
     sT(:,:,2) = signals(:,:,2)’;
       G = ndmult (signals,sT,[1 2]);
    

    来源

    原始来源。我添加了内联注释。

    function M = ndmult (A,B,dim)
      dA = dim(1);
      dB = dim(2);
    
      # reshape A into 2d
      sA = size (A);
      nA = length (sA);
      perA = [1:(dA-1) (dA+1):(nA-1) nA dA](1:nA);
      Ap = permute (A, perA);
      Ap = reshape (Ap, prod (sA(perA(1:end-1))), sA(perA(end)));
    
      # reshape B into 2d
      sB = size (B);
      nB = length (sB);
      perB = [dB 1:(dB-1) (dB+1):(nB-1) nB](1:nB);
      Bp = permute (B, perB);
      Bp = reshape (Bp, sB(perB(1)), prod (sB(perB(2:end))));
    
      # multiply
      M = Ap * Bp;
    
      # reshape back to original format
      s = [sA(perA(1:end-1)) sB(perB(2:end))];
      M = squeeze (reshape (M, s));
    endfunction
    
        7
  •  1
  •   Ali Mirzaei    7 年前

    我强烈建议你使用 MMX toolbox MATLAB的。它可以尽可能快地将n维矩阵相乘。

    的优势 MMX 是:

    1. 它是 容易的 使用。
    2. 乘法 n维矩阵 (实际上它可以使二维矩阵的数组相乘)
    3. 它执行其他 矩阵运算 (转置、二次乘、CHOL分解等)
    4. 它使用 C编译器 多线程 加速计算。

    对于这个问题,只需编写以下命令:

    C=mmx('mul',X,Y);
    

    这里是所有可能方法的基准。有关更多详细信息,请参阅 question .

        1.6571 # FOR-loop
        4.3110 # ARRAYFUN
        3.3731 # NUM2CELL/FOR-loop/CELL2MAT
        2.9820 # NUM2CELL/CELLFUN/CELL2MAT
        0.0244 # Loop Unrolling
        0.0221 # MMX toolbox  <===================
    
        8
  •  0
  •   Kevin    15 年前

    我会认为递归,但这是你唯一能做的其他非循环方法。

        9
  •  0
  •   µBio    15 年前

    您可以“展开”循环,即按顺序写出循环中可能发生的所有乘法。