How to optimize matrix multiplication (matmul) code to run fast on a single processor core
Solution 1:
The state-of-the-art implementation of matrix multiplication on CPUs uses the GotoBLAS algorithm. Basically, the loops are organized in the following order:
Loop5 for jc = 0 to N-1 in steps of NC
  Loop4 for kc = 0 to K-1 in steps of KC
    // Pack KC x NC block of B
    Loop3 for ic = 0 to M-1 in steps of MC
      // Pack MC x KC block of A
      // -------------------- Macro Kernel --------------------
      Loop2 for jr = 0 to NC-1 in steps of NR
        Loop1 for ir = 0 to MC-1 in steps of MR
          // -------------------- Micro Kernel --------------------
          Loop0 for k = 0 to KC-1 in steps of 1
            // Update MR x NR block of C
A key insight underlying modern high-performance implementations of matrix multiplication is to organize the computation by partitioning the operands into blocks for temporal locality (the three outermost loops), and to pack (copy) those blocks into contiguous buffers that fit into the various levels of the memory hierarchy for spatial locality (the three innermost loops).
The above figure (originally from this paper, directly used in this tutorial) illustrates the GotoBLAS algorithm as implemented in BLIS. The cache blocking parameters {MC, NC, KC} determine the submatrix sizes of Bp (KC × NC) and Ai (MC × KC), such that they fit in the various caches. During the computation, row panels of B are packed contiguously into the buffer Bp to fit in the L3 cache. Blocks of A are similarly packed into the buffer Ai to fit in the L2 cache. The register block sizes {MR, NR} relate to the submatrices held in registers that contribute to C. In the micro-kernel (the innermost loop), a small MR × NR micro-tile of C is updated by a pair of MR × KC and KC × NR slivers of Ai and Bp.
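To make this loop structure concrete, here is a minimal plain-C sketch of the five-loop blocked algorithm with packing. It is only a sketch under stated assumptions: column-major storage, illustrative (untuned) block sizes, and register tiles collapsed to 1 × 1 for clarity, where a real BLIS micro-kernel would update an MR × NR tile with SIMD/FMA instructions. The names blocked_dgemm, Apack, and Bpack are made up for this example.
#include <string.h>

enum { MC = 64, KC = 64, NC = 64 };   /* illustrative, untuned block sizes */

static double Apack[MC * KC];         /* packed block of A (sized for L2)  */
static double Bpack[KC * NC];         /* packed panel of B (sized for L3)  */

static int min_int(int a, int b) { return a < b ? a : b; }

/* C := C + A*B for n x n column-major matrices. */
void blocked_dgemm(int n, const double *A, const double *B, double *C) {
    for (int jc = 0; jc < n; jc += NC)                  /* Loop 5 */
        for (int kc = 0; kc < n; kc += KC) {            /* Loop 4 */
            int kb = min_int(KC, n - kc), jb = min_int(NC, n - jc);
            for (int j = 0; j < jb; j++)                /* pack kb x jb block of B */
                memcpy(&Bpack[j * kb], &B[kc + (jc + j) * n], kb * sizeof(double));
            for (int ic = 0; ic < n; ic += MC) {        /* Loop 3 */
                int ib = min_int(MC, n - ic);
                for (int k = 0; k < kb; k++)            /* pack ib x kb block of A */
                    for (int i = 0; i < ib; i++)
                        Apack[i + k * ib] = A[(ic + i) + (kc + k) * n];
                /* macro-kernel: Loops 2 and 1 over (here 1 x 1) register tiles */
                for (int jr = 0; jr < jb; jr++)
                    for (int ir = 0; ir < ib; ir++) {
                        double c = 0.0;
                        for (int k = 0; k < kb; k++)    /* Loop 0: micro-kernel */
                            c += Apack[ir + k * ib] * Bpack[k + jr * kb];
                        C[(ic + ir) + (jc + jr) * n] += c;
                    }
            }
        }
}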
For Strassen's algorithm, which has O(N^2.807) complexity, you might be interested in reading this paper. Other fast matrix multiplication algorithms with asymptotic complexity below O(N^3) are extended in a similar fashion in this paper. There is also a recent thesis about practical fast matrix multiplication algorithms.
The following tutorials might be helpful if you want to learn more about how to optimize matrix multiplication on CPUs:
How to Optimize GEMM Wiki
GEMM: From Pure C to SSE Optimized Micro Kernels
BLISlab: A sandbox for optimizing GEMM for CPU and ARM
The most up-to-date document about how to optimize GEMM on CPUs (with AVX2/FMA) step by step can be downloaded here: https://github.com/ULAFF/LAFF-On-HPC/blob/master/LAFF-On-PfHP.pdf
A Massive Open Online Course to be offered on edX starting in June 2019 (LAFF-On Programming for High Performance): https://github.com/ULAFF/LAFF-On-HPC http://www.cs.utexas.edu/users/flame/laff/pfhp/LAFF-On-PfHP.html
Solution 2:
My C is quite rusty, and I don't know which of the following the optimizer is already doing, but here goes...
Since virtually all the time is spent doing a dot product, let me just optimize that; you can build from there.
/* Assuming the column-major layout used elsewhere in this thread:
   C(i,j) += sum over k of A(i,k) * B(k,j). Row i of A has stride n;
   column j of B is contiguous. */
double* pa = &A[i];          // walks row i of A, advancing by n
double* pb = &B[j*n];        // walks column j of B, advancing by 1
double* pc = &C[i+j*n];      // the C(i,j) element being accumulated
for( int k = 0; k < n; k++ )
{
    *pc += *pa * *pb++;      // A(i,k) * B(k,j)
    pa += n;
}
Your code is probably spending more time on subscript arithmetic than anything else. My code uses pointer increments of +=8 and +=(n<<3) bytes, which is a lot more efficient. (Note: a double takes 8 bytes.)
Other optimizations:
If you know the value of n, you could "unroll" at least the innermost loop. This eliminates the loop-control overhead of the for.
Even if you only knew that n was even, you could iterate n/2 times, doubling up on the code in each iteration. This would cut the for overhead roughly in half.
I did not check to see whether the matrix multiply could be better done in row-major versus column-major order. +=8 is faster than +=(n<<3); this would be a small improvement in the outer loops.
Another way to "unroll" would be to do two dot-products in the same inner loop; see the sketch after this list. (I guess I am getting too complex to even explain.)
CPUs are "superscalar" these days. This means that they can, to some extent, do multiple things at the same time. But it does not mean that operations which must be done consecutively can be overlapped that way. Doing two independent dot products in the same loop may provide more opportunities for superscalar execution.
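Here is a sketch of the unrolling and dual-accumulator ideas above, assuming n is even and the same column-major layout as the earlier snippet; dot_row_col is a made-up helper name. The k loop is unrolled by two with two independent partial sums, so consecutive multiply-adds do not form one long dependency chain.
/* Dot product of row i of A with column j of B, unrolled by two. */
double dot_row_col(const double *A, const double *B, int n, int i, int j) {
    const double *pa = &A[i];       /* row i of A: stride n          */
    const double *pb = &B[j * n];   /* column j of B: contiguous     */
    double s0 = 0.0, s1 = 0.0;      /* independent accumulators      */
    for (int k = 0; k < n; k += 2) {
        s0 += pa[0] * pb[0];        /* A(i,k)   * B(k,j)             */
        s1 += pa[n] * pb[1];        /* A(i,k+1) * B(k+1,j)           */
        pa += 2 * n;
        pb += 2;
    }
    return s0 + s1;                 /* combine the partial sums once */
}
A caller would then do C[i + j*n] += dot_row_col(A, B, n, i, j); inside the i and j loops.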
Solution 3:
There are a lot of straightforward improvements. The basic optimizations are what Rick James wrote. Furthermore, you can rearrange the first matrix by rows and the second one by columns, so that your for() loops always do ++ and never do +=n. Loops that jump by n are much slower than loops that advance by ++. A sketch of this rearrangement follows.
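This is a minimal sketch of that idea, assuming the column-major layout used in the other answers; matmul_packed and At are illustrative names. A is transposed once up front so that both operands are walked with ++ in the inner loop.
#include <stdlib.h>

/* C := C + A*B for n x n column-major matrices, with A repacked first. */
void matmul_packed(int n, const double *A, const double *B, double *C) {
    double *At = malloc((size_t)n * n * sizeof *At);
    for (int i = 0; i < n; i++)           /* At holds A row-by-row:     */
        for (int k = 0; k < n; k++)       /* row i of A becomes a       */
            At[k + i * n] = A[i + k * n]; /* contiguous run in At       */
    for (int j = 0; j < n; j++)
        for (int i = 0; i < n; i++) {
            const double *pa = &At[i * n];  /* contiguous row i of A    */
            const double *pb = &B[j * n];   /* contiguous column j of B */
            double s = 0.0;
            for (int k = 0; k < n; k++)
                s += *pa++ * *pb++;         /* both advance by ++ only  */
            C[i + j * n] += s;
        }
    free(At);
}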
But most of those optimizations lose their punch, because a good compiler will do them for you when you use the -O3 flag. It will unroll the loops, reuse registers, do logical operations instead of multiplications, etc. It will even swap your for i and for j loops if necessary.
The core problem with your code is that with N×N matrices, you use 3 loops, forcing you to do O(N^3) operations. This is very slow. I think that state-of-the-art algorithms do only ~O(N^2.37) operations (link here). For large matrices (say N = 5000) this is a hell of a strong optimization. You can implement Strassen's algorithm easily, which will bring you down to ~N^2.807, or use it in combination with the Karatsuba algorithm, which can speed things up even for regular scalar multiplications. But don't implement anything on your own: download an open-source implementation. Matrix multiplication is a huge topic with a lot of research and very fast algorithms. Using 3 loops is not considered a valid way to do this work efficiently. Good luck!
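For reference, here is a sketch of one Strassen step on a scalar 2 × 2 case (column-major; strassen_2x2 is an illustrative name). Seven multiplications replace the usual eight, which is what yields the ~N^2.807 bound when applied recursively; in a real implementation the entries would be submatrix blocks, with the recursion bottoming out in a classical blocked kernel below some cutoff.
/* One Strassen step: C = A*B for 2x2 column-major matrices,
   using 7 multiplications instead of 8. */
void strassen_2x2(const double A[4], const double B[4], double C[4]) {
    double a11 = A[0], a21 = A[1], a12 = A[2], a22 = A[3];
    double b11 = B[0], b21 = B[1], b12 = B[2], b22 = B[3];
    double m1 = (a11 + a22) * (b11 + b22);
    double m2 = (a21 + a22) * b11;
    double m3 = a11 * (b12 - b22);
    double m4 = a22 * (b21 - b11);
    double m5 = (a11 + a12) * b22;
    double m6 = (a21 - a11) * (b11 + b12);
    double m7 = (a12 - a22) * (b21 + b22);
    C[0] = m1 + m4 - m5 + m7;   /* c11 */
    C[1] = m2 + m4;             /* c21 */
    C[2] = m3 + m5;             /* c12 */
    C[3] = m1 - m2 + m3 + m6;   /* c22 */
}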
Solution 4:
Instead of optimizing, you can obfuscate the code to make it look like it is optimized. Here is a matrix multiplication with a single null-bodied for loop (!):
/* This routine performs a dgemm operation
* C := C + A * B
* where A, B, and C are lda-by-lda matrices stored in column-major format.
* On exit, A and B maintain their input values.
* This implementation uses a single for loop: it has been optimised for space,
* namely vertical space in the source file! */
void square_dgemm(int n, const double *A, const double *B, double *C) {
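    /* How the single loop works: the || chain abuses short-circuit
     * evaluation. ++k runs every iteration; when it reaches n, the next
     * clause resets k to 0 and advances j; when j reaches n, the last
     * clause resets j to 0 and advances i. The multiply-add itself hides
     * in the for-loop's increment expression, so the body stays empty. */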
for (int i = 0, j = 0, k = -1;
++k < n || ++j < n + (k = 0) || ++i < n + (j = 0);
C[i+j*n] += A[i+k*n] * B[k+j*n]) {}
}