我想你需要计算一下Pythorch的工作原理
BxCxHxW : number of mini-batches, channels, height, width
格式,也可以使用
matmul
,因为
bmm
使用张量或ndim/dim/rank=3。
我知道你可能会在网上找到这个,但无论如何:
batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)
res.size() # torch.Size([10, 3, 20, 30])