Matrix multiplication example

Matrix multiplication is an operation performed in many data-intensive applications. It consists of groups of arithmetic operations that are repeated in a straightforward way:

[Diagram: matrix multiplication, showing each element of the result as the dot product of a row of the first matrix with a column of the second]

The matrix multiplication process is as follows:

  • Take a row in the first matrix.
  • Perform a dot product of this row with a column from the second matrix.
  • Store the result in the corresponding row and column of a new matrix.

For matrices of 32-bit floats, the multiplication could be written as:

void matrix_multiply_c(float32_t *A, float32_t *B, float32_t *C, uint32_t n, uint32_t m, uint32_t k) {
    for (int i_idx=0; i_idx < n; i_idx++) {
        for (int j_idx=0; j_idx < m; j_idx++) {
            C[n*j_idx + i_idx] = 0;
            for (int k_idx=0; k_idx < k; k_idx++) {
                C[n*j_idx + i_idx] += A[n*k_idx + i_idx]*B[k*j_idx + k_idx];
            }
        }
    }
}

We have assumed a column-major layout of the matrices in memory. That is, an n x m matrix M is represented as an array M_array where Mij = M_array[n*j + i].

This code is suboptimal, since it does not make full use of Neon. We can begin to improve it by using intrinsics, but let’s tackle a simpler problem first by looking at small, fixed-size matrices before moving on to larger matrices.

The following code uses intrinsics to multiply two 4x4 matrices. Since we have a small and fixed number of values to process, all of which can fit into the processor’s Neon registers at once, we can completely unroll the loops.

void matrix_multiply_4x4_neon(float32_t *A, float32_t *B, float32_t *C) {
	// these are the columns of A
	float32x4_t A0;
	float32x4_t A1;
	float32x4_t A2;
	float32x4_t A3;
	
	// these are the columns of B
	float32x4_t B0;
	float32x4_t B1;
	float32x4_t B2;
	float32x4_t B3;
	
	// these are the columns of C
	float32x4_t C0;
	float32x4_t C1;
	float32x4_t C2;
	float32x4_t C3;
	
	A0 = vld1q_f32(A);
	A1 = vld1q_f32(A+4);
	A2 = vld1q_f32(A+8);
	A3 = vld1q_f32(A+12);
	
	// Zero accumulators for C values
	C0 = vmovq_n_f32(0);
	C1 = vmovq_n_f32(0);
	C2 = vmovq_n_f32(0);
	C3 = vmovq_n_f32(0);
	
	// Multiply accumulate in 4x1 blocks, i.e. each column in C
	B0 = vld1q_f32(B);
	C0 = vfmaq_laneq_f32(C0, A0, B0, 0);
	C0 = vfmaq_laneq_f32(C0, A1, B0, 1);
	C0 = vfmaq_laneq_f32(C0, A2, B0, 2);
	C0 = vfmaq_laneq_f32(C0, A3, B0, 3);
	vst1q_f32(C, C0);
	
	B1 = vld1q_f32(B+4);
	C1 = vfmaq_laneq_f32(C1, A0, B1, 0);
	C1 = vfmaq_laneq_f32(C1, A1, B1, 1);
	C1 = vfmaq_laneq_f32(C1, A2, B1, 2);
	C1 = vfmaq_laneq_f32(C1, A3, B1, 3);
	vst1q_f32(C+4, C1);
	
	B2 = vld1q_f32(B+8);
	C2 = vfmaq_laneq_f32(C2, A0, B2, 0);
	C2 = vfmaq_laneq_f32(C2, A1, B2, 1);
	C2 = vfmaq_laneq_f32(C2, A2, B2, 2);
	C2 = vfmaq_laneq_f32(C2, A3, B2, 3);
	vst1q_f32(C+8, C2);
	
	B3 = vld1q_f32(B+12);
	C3 = vfmaq_laneq_f32(C3, A0, B3, 0);
	C3 = vfmaq_laneq_f32(C3, A1, B3, 1);
	C3 = vfmaq_laneq_f32(C3, A2, B3, 2);
	C3 = vfmaq_laneq_f32(C3, A3, B3, 3);
	vst1q_f32(C+12, C3);
}

We have chosen to multiply fixed size 4x4 matrices for a few reasons:

  • Some applications need 4x4 matrices specifically, for example graphics or relativistic physics.
  • The Neon vector registers hold four 32-bit values, so matching the program to the architecture will make it easier to optimize.
  • We can take this 4x4 kernel and use it in a more general one.

Let's summarize the intrinsics that have been used here:

  • float32x4_t: an array of four 32-bit floats. One float32x4_t fits into a 128-bit register, so we can ensure there are no wasted register bits even in C code.
  • vld1q_f32(…): a function which loads four 32-bit floats into a float32x4_t. We use it to get the matrix values we need from A and B.
  • vfmaq_laneq_f32(…): a function which uses the fused multiply-accumulate instruction. It multiplies a float32x4_t value by a single lane of another float32x4_t, then adds the result to a third float32x4_t before returning the result. Since the matrix row-on-column dot products are a set of multiplications and additions, this operation fits quite naturally.
  • vst1q_f32(…): a function which stores a float32x4_t at a given address. We use it to store the results after they are calculated.

Now that we can multiply a 4x4 matrix, we can multiply larger matrices by treating them as blocks of 4x4 matrices. A flaw in this approach is that it only works for matrix sizes which are a multiple of four in both dimensions, but by padding any matrix with zeroes you can use this method without changing it.

The code for a more general matrix multiplication is listed below. The structure of the kernel has changed very little, with the addition of loops and address calculations being the major changes. As in the 4x4 kernel we have used unique variable names for the columns of B, even though we could have used one variable and re-loaded. This acts as a hint to the compiler to assign different registers to these variables, which will enable the processor to complete the arithmetic instructions for one column while waiting on the loads for another.

void matrix_multiply_neon(float32_t  *A, float32_t  *B, float32_t *C, uint32_t n, uint32_t m, uint32_t k) {
	/* 
	 * Multiply matrices A and B, store the result in C. 
	 * It is the user's responsibility to make sure the matrices are compatible.
	 */	

	int A_idx;
	int B_idx;
	int C_idx;
	
	// these are the columns of a 4x4 sub matrix of A
	float32x4_t A0;
	float32x4_t A1;
	float32x4_t A2;
	float32x4_t A3;
	
	// these are the columns of a 4x4 sub matrix of B
	float32x4_t B0;
	float32x4_t B1;
	float32x4_t B2;
	float32x4_t B3;
	
	// these are the columns of a 4x4 sub matrix of C
	float32x4_t C0;
	float32x4_t C1;
	float32x4_t C2;
	float32x4_t C3;
	
	for (int i_idx=0; i_idx<n; i_idx+=4) {
		for (int j_idx=0; j_idx<m; j_idx+=4) {
			// Zero accumulators before matrix op
			C0 = vmovq_n_f32(0);
			C1 = vmovq_n_f32(0);
			C2 = vmovq_n_f32(0);
			C3 = vmovq_n_f32(0);
			for (int k_idx=0; k_idx<k; k_idx+=4) {
				// Compute base index to 4x4 block
				A_idx = i_idx + n*k_idx;
				B_idx = k*j_idx + k_idx;
				
				// Load most current A values in row
				A0 = vld1q_f32(A+A_idx);
				A1 = vld1q_f32(A+A_idx+n);
				A2 = vld1q_f32(A+A_idx+2*n);
				A3 = vld1q_f32(A+A_idx+3*n);
				
				// Multiply accumulate in 4x1 blocks, i.e. each column in C
				B0 = vld1q_f32(B+B_idx);
				C0 = vfmaq_laneq_f32(C0, A0, B0, 0);
				C0 = vfmaq_laneq_f32(C0, A1, B0, 1);
				C0 = vfmaq_laneq_f32(C0, A2, B0, 2);
				C0 = vfmaq_laneq_f32(C0, A3, B0, 3);
				
				B1 = vld1q_f32(B+B_idx+k);
				C1 = vfmaq_laneq_f32(C1, A0, B1, 0);
				C1 = vfmaq_laneq_f32(C1, A1, B1, 1);
				C1 = vfmaq_laneq_f32(C1, A2, B1, 2);
				C1 = vfmaq_laneq_f32(C1, A3, B1, 3);
				
				B2 = vld1q_f32(B+B_idx+2*k);
				C2 = vfmaq_laneq_f32(C2, A0, B2, 0);
				C2 = vfmaq_laneq_f32(C2, A1, B2, 1);
				C2 = vfmaq_laneq_f32(C2, A2, B2, 2);
				C2 = vfmaq_laneq_f32(C2, A3, B2, 3);
				
				B3 = vld1q_f32(B+B_idx+3*k);
				C3 = vfmaq_laneq_f32(C3, A0, B3, 0);
				C3 = vfmaq_laneq_f32(C3, A1, B3, 1);
				C3 = vfmaq_laneq_f32(C3, A2, B3, 2);
				C3 = vfmaq_laneq_f32(C3, A3, B3, 3);
			}
			// Compute base index for stores
			C_idx = n*j_idx + i_idx;
			vst1q_f32(C+C_idx, C0);
			vst1q_f32(C+C_idx+n, C1);
			vst1q_f32(C+C_idx+2*n, C2);
			vst1q_f32(C+C_idx+3*n, C3);
		}
	}
}

Compiling and disassembling this function, and comparing it with our C function shows:

  • Fewer arithmetic instructions for a given matrix multiplication, since we are leveraging the Advanced SIMD technology with full register packing. Pure C code generally does not do this.
  • FMLA instead of FMUL instructions, as specified by the intrinsics.
  • Fewer loop iterations. When used properly, intrinsics allow loops to be unrolled easily.
  • However, there are unnecessary loads and stores due to memory allocation and initialization of data types (for example, float32x4_t) which are not used in the pure C code.

Full source code example:

    /*
     * Copyright (C) Arm Limited, 2019 All rights reserved. 
     * 
     * The example code is provided to you as an aid to learning when working 
     * with Arm-based technology, including but not limited to programming tutorials. 
     * Arm hereby grants to you, subject to the terms and conditions of this Licence, 
     * a non-exclusive, non-transferable, non-sub-licensable, free-of-charge licence, 
     * to use and copy the Software solely for the purpose of demonstration and 
     * evaluation.
     * 
     * You accept that the Software has not been tested by Arm therefore the Software 
     * is provided "as is", without warranty of any kind, express or implied. In no 
     * event shall the authors or copyright holders be liable for any claim, damages 
     * or other liability, whether in action or contract, tort or otherwise, arising 
     * from, out of or in connection with the Software or the use of Software.
     */
    
    #include <stdio.h>
    #include <stdint.h>
    #include <stdlib.h>
    #include <stdbool.h>
    #include <math.h>
    
    #include <arm_neon.h>
    
    #define BLOCK_SIZE 4
    
    
    void matrix_multiply_c(float32_t *A, float32_t *B, float32_t *C, uint32_t n, uint32_t m, uint32_t k) {
    	for (int i_idx=0; i_idx<n; i_idx++) {
    		for (int j_idx=0; j_idx<m; j_idx++) {
    			C[n*j_idx + i_idx] = 0;
    			for (int k_idx=0; k_idx<k; k_idx++) {
    				C[n*j_idx + i_idx] += A[n*k_idx + i_idx]*B[k*j_idx + k_idx];
    			}
    		}
    	}
    }
    
    void matrix_multiply_neon(float32_t  *A, float32_t  *B, float32_t *C, uint32_t n, uint32_t m, uint32_t k) {
    	/* 
    	 * Multiply matrices A and B, store the result in C. 
    	 * It is the user's responsibility to make sure the matrices are compatible.
    	 */	
    
    	int A_idx;
    	int B_idx;
    	int C_idx;
    	
    	// these are the columns of a 4x4 sub matrix of A
    	float32x4_t A0;
    	float32x4_t A1;
    	float32x4_t A2;
    	float32x4_t A3;
    	
    	// these are the columns of a 4x4 sub matrix of B
    	float32x4_t B0;
    	float32x4_t B1;
    	float32x4_t B2;
    	float32x4_t B3;
    	
    	// these are the columns of a 4x4 sub matrix of C
    	float32x4_t C0;
    	float32x4_t C1;
    	float32x4_t C2;
    	float32x4_t C3;
    	
    	for (int i_idx=0; i_idx<n; i_idx+=4) {
    		for (int j_idx=0; j_idx<m; j_idx+=4) {
    			// Zero accumulators before matrix op
    			C0 = vmovq_n_f32(0);
    			C1 = vmovq_n_f32(0);
    			C2 = vmovq_n_f32(0);
    			C3 = vmovq_n_f32(0);
    			for (int k_idx=0; k_idx<k; k_idx+=4) {
    				// Compute base index to 4x4 block
    				A_idx = i_idx + n*k_idx;
    				B_idx = k*j_idx + k_idx;
    				
    				// Load most current A values in row 
    				A0 = vld1q_f32(A+A_idx);
    				A1 = vld1q_f32(A+A_idx+n);
    				A2 = vld1q_f32(A+A_idx+2*n);
    				A3 = vld1q_f32(A+A_idx+3*n);
    				
    				// Multiply accumulate in 4x1 blocks, i.e. each column in C
    				B0 = vld1q_f32(B+B_idx);
    				C0 = vfmaq_laneq_f32(C0, A0, B0, 0);
    				C0 = vfmaq_laneq_f32(C0, A1, B0, 1);
    				C0 = vfmaq_laneq_f32(C0, A2, B0, 2);
    				C0 = vfmaq_laneq_f32(C0, A3, B0, 3);
    				
    				B1 = vld1q_f32(B+B_idx+k);
    				C1 = vfmaq_laneq_f32(C1, A0, B1, 0);
    				C1 = vfmaq_laneq_f32(C1, A1, B1, 1);
    				C1 = vfmaq_laneq_f32(C1, A2, B1, 2);
    				C1 = vfmaq_laneq_f32(C1, A3, B1, 3);
    				
    				B2 = vld1q_f32(B+B_idx+2*k);
    				C2 = vfmaq_laneq_f32(C2, A0, B2, 0);
    				C2 = vfmaq_laneq_f32(C2, A1, B2, 1);
    				C2 = vfmaq_laneq_f32(C2, A2, B2, 2);
    				C2 = vfmaq_laneq_f32(C2, A3, B2, 3);
    				
    				B3 = vld1q_f32(B+B_idx+3*k);
    				C3 = vfmaq_laneq_f32(C3, A0, B3, 0);
    				C3 = vfmaq_laneq_f32(C3, A1, B3, 1);
    				C3 = vfmaq_laneq_f32(C3, A2, B3, 2);
    				C3 = vfmaq_laneq_f32(C3, A3, B3, 3);
    			}
    			// Compute base index for stores
    			C_idx = n*j_idx + i_idx;
    			vst1q_f32(C+C_idx, C0);
    			vst1q_f32(C+C_idx+n, C1);
    			vst1q_f32(C+C_idx+2*n, C2);
    			vst1q_f32(C+C_idx+3*n, C3);
    		}
    	}
    }
    
    void matrix_multiply_4x4_neon(float32_t *A, float32_t *B, float32_t *C) {
    	// these are the columns of A
    	float32x4_t A0;
    	float32x4_t A1;
    	float32x4_t A2;
    	float32x4_t A3;
    	
    	// these are the columns of B
    	float32x4_t B0;
    	float32x4_t B1;
    	float32x4_t B2;
    	float32x4_t B3;
    	
    	// these are the columns of C
    	float32x4_t C0;
    	float32x4_t C1;
    	float32x4_t C2;
    	float32x4_t C3;
    	
    	A0 = vld1q_f32(A);
    	A1 = vld1q_f32(A+4);
    	A2 = vld1q_f32(A+8);
    	A3 = vld1q_f32(A+12);
    	
    	// Zero accumulators for C values
    	C0 = vmovq_n_f32(0);
    	C1 = vmovq_n_f32(0);
    	C2 = vmovq_n_f32(0);
    	C3 = vmovq_n_f32(0);
    	
    	// Multiply accumulate in 4x1 blocks, i.e. each column in C
    	B0 = vld1q_f32(B);
    	C0 = vfmaq_laneq_f32(C0, A0, B0, 0);
    	C0 = vfmaq_laneq_f32(C0, A1, B0, 1);
    	C0 = vfmaq_laneq_f32(C0, A2, B0, 2);
    	C0 = vfmaq_laneq_f32(C0, A3, B0, 3);
    	vst1q_f32(C, C0);
    	
    	B1 = vld1q_f32(B+4);
    	C1 = vfmaq_laneq_f32(C1, A0, B1, 0);
    	C1 = vfmaq_laneq_f32(C1, A1, B1, 1);
    	C1 = vfmaq_laneq_f32(C1, A2, B1, 2);
    	C1 = vfmaq_laneq_f32(C1, A3, B1, 3);
    	vst1q_f32(C+4, C1);
    	
    	B2 = vld1q_f32(B+8);
    	C2 = vfmaq_laneq_f32(C2, A0, B2, 0);
    	C2 = vfmaq_laneq_f32(C2, A1, B2, 1);
    	C2 = vfmaq_laneq_f32(C2, A2, B2, 2);
    	C2 = vfmaq_laneq_f32(C2, A3, B2, 3);
    	vst1q_f32(C+8, C2);
    	
    	B3 = vld1q_f32(B+12);
    	C3 = vfmaq_laneq_f32(C3, A0, B3, 0);
    	C3 = vfmaq_laneq_f32(C3, A1, B3, 1);
    	C3 = vfmaq_laneq_f32(C3, A2, B3, 2);
    	C3 = vfmaq_laneq_f32(C3, A3, B3, 3);
    	vst1q_f32(C+12, C3);
    }
    
    void print_matrix(float32_t *M, uint32_t cols, uint32_t rows) {
    	for (int i=0; i<rows; i++) {
    		for (int j=0; j<cols; j++) {
    			printf("%f ", M[j*rows + i]);
    		}
    		printf("\n");
    	}
    	printf("\n");
    }
    
    void matrix_init_rand(float32_t *M, uint32_t numvals) {
    	for (int i=0; i<numvals; i++) {
    		M[i] = (float)rand()/(float)(RAND_MAX);
    	}
    }
    
    void matrix_init(float32_t *M, uint32_t cols, uint32_t rows, float32_t val) {
    	for (int i=0; i<rows; i++) {
    		for (int j=0; j<cols; j++) {
    			M[j*rows + i] = val;
    		}
    	}
    }
    
    bool f32comp_noteq(float32_t a, float32_t b) {
    	if (fabs(a-b) < 0.000001) {
    		return false;
    	}
    	return true;
    }
    
    bool matrix_comp(float32_t *A, float32_t *B, uint32_t rows, uint32_t cols) {
    	float32_t a;
    	float32_t b;
    	for (int i=0; i<rows; i++) {
    		for (int j=0; j<cols; j++) {
    			a = A[rows*j + i];
    			b = B[rows*j + i];	
    			
    			if (f32comp_noteq(a, b)) {
    				printf("i=%d, j=%d, A=%f, B=%f\n", i, j, a, b);
    				return false;
    			}
    		}
    	}
    	return true;
    }
    
    int main() {
    	uint32_t n = 2*BLOCK_SIZE; // rows in A
    	uint32_t m = 2*BLOCK_SIZE; // cols in B
    	uint32_t k = 2*BLOCK_SIZE; // cols in A and rows in B
    	
    	float32_t A[n*k];
    	float32_t B[k*m];
    	float32_t C[n*m];
    	float32_t D[n*m];
    	float32_t E[n*m];
    	
    	bool c_eq_asm;
    	bool c_eq_neon;
    
    	matrix_init_rand(A, n*k);
    	matrix_init_rand(B, k*m);
    	matrix_init(C, n, m, 0);
    
    	print_matrix(A, k, n);
    	print_matrix(B, m, k);
    	//print_matrix(C, n, m);
    	
    	matrix_multiply_c(A, B, E, n, m, k);
    	printf("C\n");
    	print_matrix(E, n, m);
    	printf("===============================\n");
    	
    	matrix_multiply_neon(A, B, D, n, m, k);
    	printf("Neon\n");
    	print_matrix(D, n, m);
    	c_eq_neon = matrix_comp(E, D, n, m);
    	printf("Neon equal to C? %d\n", c_eq_neon);
    	printf("===============================\n");
    }

The full source code above can be compiled and disassembled on an Arm machine using the following commands:

gcc -g -O3 matrix.c -o exe_matrix_o3
objdump -d exe_matrix_o3 > disasm_matrix_o3

If you don't have access to Arm-based hardware, you can use Arm DS-5 Community Edition and the Armv8-A Foundation Platform.
