## 3D Matrix Multiplication Tricks

Recently, I came across 3D matrix multiplication in Matlab. It arises in frequency-domain signal processing, specifically in the frequency-dependent multi-channel Wiener filter. Here, 3D matrix multiplication means a 2D matrix multiplication at each frequency (or other parameter): given $\mathbf{A}(f)$ and $\mathbf{B}(f)$, which are matrices (or vectors) at each frequency $f$, compute $\mathbf{C}(f)=\mathbf{A}(f)\mathbf{B}(f)$.
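If $\mathbf{A}(f)$ is $M\times K$ and $\mathbf{B}(f)$ is $K\times N$, the product written out entry by entry is

$$
[\mathbf{C}(f)]_{mn} = \sum_{k=1}^{K} [\mathbf{A}(f)]_{mk}\,[\mathbf{B}(f)]_{kn}, \qquad m = 1,\dots,M,\; n = 1,\dots,N.
$$

Every method in this post evaluates this same sum; they differ only in which of the indices $f$, $m$, $n$ are looped over and which are handled by vectorized array operations.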

In Matlab, we hold $\mathbf{A}(f)$ in a 3D array $\mathtt{A(1:F,1:M,1:K)}$ and $\mathbf{B}(f)$ in $\mathtt{B(1:F,1:K,1:N)}$, with frequency along the first dimension. A straightforward implementation of 3D matrix multiplication loops over frequency:

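```matlab
A1 = permute(A, [2 3 1]);   % M-by-K-by-F
B1 = permute(B, [2 3 1]);   % K-by-N-by-F
C = zeros(F, M, N);         % preallocate the output
for f = 1:F
    C(f,:,:) = reshape(A1(:,:,f) * B1(:,:,f), [1 M N]);
end
```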

But if $\mathtt{F}$ is large while $\mathtt{M*N}$ is small (for example, $\mathtt{F=1024,M=2,N=5}$), as is typical for a multi-channel Wiener filter, the loop over frequency slows the computation down significantly.

Matlab's real strength lies in vectorized matrix and vector operations, which are very fast. (The JIT accelerator helps loops a lot in some cases, but not always.) The strategy here is to replace the loop over frequency with loops over the second and third dimensions of $\mathtt{C}$, and to vectorize the inner operation along frequency:

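```matlab
A1 = permute(A, [1 3 2]);   % F-by-K-by-M, so A1(:,:,m) is A(:,m,:) as an F-by-K matrix
C = zeros(F, M, N);         % preallocate the output
for m = 1:M
    for n = 1:N
        % Entry (f,k) of the product is A(f,m,k)*B(f,k,n); summing over k gives C(f,m,n)
        C(:,m,n) = sum(A1(:,:,m) .* B(:,:,n), 2);
    end
end
```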

The total number of iterations is now $\mathtt{M*N}$ instead of $\mathtt{F}$, and each iteration processes a whole length-$\mathtt{F}$ column at once. Using this strategy, I observed a speedup of more than 10x for $\mathtt{F=1024}$, $\mathtt{M=2}$, and $\mathtt{N=5}$.
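To check the vectorized loop against the per-frequency loop, and to time both on your own machine, a script along these lines can be used. The sizes and the complex random test data are hypothetical stand-ins for real filter quantities ($\mathtt{K=3}$ is chosen arbitrarily), and a single `tic`/`toc` gives only a rough, machine-dependent timing.

```matlab
% Hypothetical test sizes; K = 3 is an arbitrary choice.
F = 1024; M = 2; K = 3; N = 5;
A = randn(F, M, K) + 1i*randn(F, M, K);   % complex, as in frequency-domain filtering
B = randn(F, K, N) + 1i*randn(F, K, N);

% Reference: one M-by-K times K-by-N product per frequency.
tic;
A1 = permute(A, [2 3 1]);
B1 = permute(B, [2 3 1]);
Cref = zeros(F, M, N);
for f = 1:F
    Cref(f,:,:) = reshape(A1(:,:,f) * B1(:,:,f), [1 M N]);
end
tLoopF = toc;

% Vectorized along frequency: only M*N iterations.
tic;
A2 = permute(A, [1 3 2]);                 % F-by-K-by-M
C = zeros(F, M, N);
for m = 1:M
    for n = 1:N
        C(:,m,n) = sum(A2(:,:,m) .* B(:,:,n), 2);
    end
end
tLoopMN = toc;

fprintf('loop over F: %.4f s, loop over M*N: %.4f s\n', tLoopF, tLoopMN);
fprintf('max abs difference: %g\n', max(abs(C(:) - Cref(:))));
```

The two results should agree up to floating-point rounding, since the matrix product and `sum` may accumulate the $K$ terms in a different order.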

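In the special case $\mathtt{K==1}$, each $\mathbf{A}(f)$ is a column vector, each $\mathbf{B}(f)$ is a row vector, and $\mathbf{C}(f)$ is their outer product, so the code above simplifies further: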

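```matlab
B1 = squeeze(B);            % F-by-1-by-N -> F-by-N
C = zeros(F, M, N);         % preallocate the output
for n = 1:N
    % Scale column m of A (F-by-M) by B(:,1,n), frequency by frequency
    C(:,:,n) = A .* repmat(B1(:,n), 1, M);
end
```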
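Because $\mathtt{C(f,m,n)=A(f,m)\,B(f,1,n)}$ when $\mathtt{K==1}$, the remaining loop over $\mathtt{n}$ can be dropped as well. The sketch below assumes MATLAB R2016b or later (or GNU Octave), where `.*` expands singleton dimensions automatically; on older releases `bsxfun(@times, A, B)` gives the same result. The array sizes are hypothetical, and the loop from the text is repeated only to check that the results agree.

```matlab
% Hypothetical sizes for illustration.
F = 1024; M = 2; N = 5;
A = randn(F, M);        % F-by-M-by-1 (trailing singleton dimension is implicit)
B = randn(F, 1, N);     % F-by-1-by-N

% Implicit expansion: A is expanded along dim 3, B along dim 2.
C = A .* B;             % F-by-M-by-N, C(f,m,n) = A(f,m)*B(f,1,n)

% Same result without implicit expansion (pre-R2016b MATLAB).
C2 = bsxfun(@times, A, B);

% Loop version from the text, for comparison.
B1 = squeeze(B);
C3 = zeros(F, M, N);
for n = 1:N
    C3(:,:,n) = A .* repmat(B1(:,n), 1, M);
end
```

Each entry is a single product of the same two numbers in all three versions, so the results should be identical, not just close.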

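Another important case is when $\mathtt{A(f,:,:)}$ is a diagonal matrix at each frequency (so $\mathtt{K==M}$). Then we can compute $\mathtt{C(1:F,1:M,1:N)}$ in one stroke:

```matlab
% A(:,1:M+1:M*M) picks the F-by-M diagonal entries; repmat copies them N times along dim 3
C = repmat(A(:,1:M+1:M*M), [1 1 N]) .* B;
```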
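The index vector `1:M+1:M*M` works because `A(:,j)` treats the trailing two dimensions of the $F\times M\times M$ array as one column index $j = m + (k-1)M$, so the diagonal entries $k = m$ sit at $j = 1, M+2, 2M+3, \dots, M^2$. The sketch below checks the one-liner against the general $\mathtt{M*N}$ loop on random diagonal data. The sizes are hypothetical, and the `C2` line assumes implicit expansion (MATLAB R2016b or later, or GNU Octave), which makes the `repmat` unnecessary.

```matlab
% Hypothetical sizes for illustration.
F = 1024; M = 3; N = 5;

% Build A(f,:,:) as a random diagonal M-by-M matrix at each frequency.
d = randn(F, M);                      % diagonal entries
A = zeros(F, M, M);
A(:, 1:M+1:M*M) = d;                  % write d onto the diagonals
B = randn(F, M, N);

% One-stroke version from the text.
C = repmat(A(:,1:M+1:M*M), [1 1 N]) .* B;

% Equivalent with implicit expansion: the F-by-M diagonal expands along dim 3.
C2 = A(:,1:M+1:M*M) .* B;

% General loop over (m,n), for comparison.
A1 = permute(A, [1 3 2]);
Cref = zeros(F, M, N);
for m = 1:M
    for n = 1:N
        Cref(:,m,n) = sum(A1(:,:,m) .* B(:,:,n), 2);
    end
end
```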

It feels very good to see the results pop out right away.