Recently, I came across 3D matrix multiplication in Matlab. It arises in frequency-domain signal processing, specifically in the frequency-dependent multi-channel Wiener filter. Here, 3D matrix multiplication means a 2D matrix multiplication at each frequency (or other parameter): given $A_f$ and $B_f$, which are $M\times K$ and $K\times N$ matrices (or vectors) at each frequency $f$, compute $C_f = A_f B_f$.
In Matlab, we hold all $A_f$ in a 3D matrix A of size $F\times M\times K$, and all $B_f$ in B of size $F\times K\times N$. With frequency along the first dimension, one plain implementation of 3D matrix multiplication is:
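```matlab
A1 = permute(A, [2 3 1]);   % M-by-K-by-F
B1 = permute(B, [2 3 1]);   % K-by-N-by-F
C = zeros(F, M, N);         % preallocate the F-by-M-by-N result
for f = 1:F
    C(f,:,:) = A1(:,:,f) * B1(:,:,f);   % one M-by-N product per frequency
end
```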
But if $F$ is large while $M$, $K$ and $N$ are small, as in the multi-channel Wiener filter, the loop over frequency slows things down significantly: each iteration does only a tiny matrix product, so the per-iteration overhead, rather than the arithmetic, tends to dominate.
The native power of Matlab lies in matrix and vector operations, which are lightning fast. (JIT helps a lot in some cases, but not always.) The strategy here is to replace the loop over frequency with loops over the second and third dimensions of $C$, and to vectorize the inner operation along frequency:
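```matlab
A1 = permute(A, [1 3 2]);   % F-by-K-by-M
C = zeros(F, M, N);
for m = 1:M
    for n = 1:N
        % A1(:,:,m) and B(:,:,n) are both F-by-K: multiply element-wise, sum over k
        C(:,m,n) = sum(A1(:,:,m) .* B(:,:,n), 2);
    end
end
```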
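Why the inner statement gives the right answer: write $C_f = A_f B_f$ entry by entry using the array layout above, where A(f,m,k) is entry $(m,k)$ of $A_f$, and let $A_1$ denote permute(A,[1 3 2]), so that $A_1(f,k,m) = A(f,m,k)$:

$$
C(f,m,n) = \sum_{k=1}^{K} A(f,m,k)\,B(f,k,n) = \sum_{k=1}^{K} A_1(f,k,m)\,B(f,k,n), \qquad f = 1,\dots,F.
$$

For fixed $m$ and $n$, the right-hand side over all $f$ at once is the element-wise product of the $F\times K$ slices A1(:,:,m) and B(:,:,n), summed along the second ($k$) dimension, which is exactly sum(A1(:,:,m).*B(:,:,n),2).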
The total number of iterations is now $MN$ instead of $F$. Using this strategy, I observed a speedup of more than 10X for the problem sizes I was working with.
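To reproduce this kind of comparison, here is a minimal sketch that runs both versions on random data, checks that they agree, and times them. The sizes below are assumptions chosen only to mimic "large $F$, small $M$, $K$, $N$"; they are not the sizes behind the 10X figure. The baseline uses reshape so the per-frequency assignment has an explicit 1-by-M-by-N shape.

```matlab
% Hypothetical sizes: large F, small M, K, N (not the original test sizes)
F = 8192; M = 2; K = 3; N = 2;
A = randn(F, M, K);    % A(f,:,:) is the M-by-K matrix A_f
B = randn(F, K, N);    % B(f,:,:) is the K-by-N matrix B_f

% Baseline: one small matrix product per frequency (F iterations)
tic;
A1 = permute(A, [2 3 1]);
B1 = permute(B, [2 3 1]);
Cloop = zeros(F, M, N);
for f = 1:F
    Cloop(f,:,:) = reshape(A1(:,:,f) * B1(:,:,f), [1 M N]);
end
tLoop = toc;

% Vectorized along frequency (M*N iterations)
tic;
A2 = permute(A, [1 3 2]);    % F-by-K-by-M
Cvec = zeros(F, M, N);
for m = 1:M
    for n = 1:N
        Cvec(:,m,n) = sum(A2(:,:,m) .* B(:,:,n), 2);
    end
end
tVec = toc;

% The two results should agree up to floating-point rounding
err = max(abs(Cloop(:) - Cvec(:)));
fprintf('max abs difference: %g\n', err);
fprintf('loop over F: %.4f s, loop over M*N: %.4f s\n', tLoop, tVec);
```

A single tic/toc pair is noisy; in MATLAB, wrapping each version in a function and measuring it with timeit usually gives steadier numbers.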
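In the special case of $K = 1$ (each $A_f$ is an $M\times 1$ column and each $B_f$ a $1\times N$ row, so $C_f$ is their outer product), the code above can be simplified further: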
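```matlab
B1 = squeeze(B);            % F-by-N (B is F-by-1-by-N when K = 1)
C = zeros(F, M, N);
for n = 1:N
    C(:,:,n) = A .* repmat(B1(:,n), 1, M);   % scale each column of A by B_f(n)
end
```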
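The sketch below checks the $K = 1$ loop against a per-frequency outer product on random data (sizes are hypothetical). It also shows an alternative not in the original: with implicit expansion (MATLAB R2016b and later; Octave calls it broadcasting), A of size $F\times M$ and B of size $F\times 1\times N$ expand against each other directly, so neither the loop nor repmat is needed.

```matlab
% Hypothetical sizes for the K = 1 case
F = 100; M = 3; N = 4;
A = randn(F, M);       % F-by-M-by-1 is stored as F-by-M
B = randn(F, 1, N);    % each B_f is a 1-by-N row vector

% Loop version from the text
B1 = squeeze(B);       % F-by-N
C = zeros(F, M, N);
for n = 1:N
    C(:,:,n) = A .* repmat(B1(:,n), 1, M);
end

% Implicit expansion: A's singleton 3rd dimension and B's singleton
% 2nd dimension are expanded, giving an F-by-M-by-N result
C2 = A .* B;

% Reference: outer product A_f * B_f at each frequency
Cref = zeros(F, M, N);
for f = 1:F
    Cref(f,:,:) = reshape(A(f,:).' * reshape(B(f,1,:), 1, N), [1 M N]);
end

fprintf('loop vs reference: %g\n', max(abs(C(:) - Cref(:))));
fprintf('expansion vs reference: %g\n', max(abs(C2(:) - Cref(:))));
```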
Another important case is when $A_f$ is a diagonal matrix at each frequency (so $K = M$). Then we can compute $C$ in one stroke:
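```matlab
% Viewing A as F-by-(M*M), columns 1, M+2, 2M+3, ... hold the diagonal of each A_f
C = repmat(A(:,1:M+1:M*M), [1 1 N]) .* B;
```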
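The indexing trick works because A(:,j) with a single trailing subscript treats the $F\times M\times M$ array as $F\times M^2$, and in column-major order the diagonal entry $(m,m)$ sits at linear position $(m-1)(M+1)+1$. The following sketch, with hypothetical sizes and random diagonals, checks the one-liner against a full matrix product at each frequency; the implicit-expansion variant (MATLAB R2016b+ / Octave broadcasting) is an alternative not in the original.

```matlab
% Hypothetical sizes for the diagonal case (K = M)
F = 64; M = 3; N = 5;
d = randn(F, M);               % diagonal entries of each A_f
A = zeros(F, M, M);
for m = 1:M
    A(:,m,m) = d(:,m);         % A(f,:,:) is diag(d(f,:))
end
B = randn(F, M, N);

% Linear indices of the diagonal within each M-by-M page
diagIdx = 1:M+1:M*M;
C = repmat(A(:,diagIdx), [1 1 N]) .* B;

% Same result with implicit expansion instead of repmat
C2 = A(:,diagIdx) .* B;

% Reference: full M-by-M times M-by-N product at each frequency
Cref = zeros(F, M, N);
for f = 1:F
    Cref(f,:,:) = reshape(squeeze(A(f,:,:)) * squeeze(B(f,:,:)), [1 M N]);
end

fprintf('max abs difference: %g\n', max(abs(C(:) - Cref(:))));
```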
It feels very good to see results jumping out right away.