Find the largest element in a matrix and its row and column indices using SSE and AVX

I need to find the largest element in a matrix stored as a 1D array, along with its row and column indices.

Because the matrix is stored as a 1D array, I only need the index of the maximum element; the row and column follow directly from that index.

My problem is that I cannot get that index.

I have a working SSE function that finds the largest element; here it is:

/*
 * Return the largest element of a dims x dims matrix stored contiguously in m.
 *
 * Whole 4-float vectors are reduced with SSE; the remaining (dims*dims) % 4
 * elements are handled by a scalar tail loop, so no element past the end of
 * the matrix is ever read (the previous version over-read whenever dims*dims
 * was not a multiple of 4, or was smaller than 4).
 *
 * Returns 0.0f when dims == 0 (empty matrix). NaN inputs are not handled
 * specially.
 */
float find_largest_element_in_matrix_SSE(float* m, unsigned const int dims)
{
    /* Widen before multiplying so large dims cannot wrap in unsigned int. */
    const size_t total = (size_t)dims * dims;
    /* Number of leading elements covered by whole 4-wide vectors. */
    const size_t nVec = total & ~(size_t)3;
    float best;
    size_t i;

    if (total == 0)
    {
        return 0.0f;
    }

    best = m[0];
    if (nVec > 0)
    {
        __m128 max_el = _mm_loadu_ps(m);
        /* Plain array + unaligned store: no compiler-specific alignment needed. */
        float lanes[4];
        int k;

        for (i = 4; i < nVec; i += 4)
        {
            max_el = _mm_max_ps(max_el, _mm_loadu_ps(m + i));
        }

        /* Horizontal reduction of the four lane maxima. */
        _mm_storeu_ps(lanes, max_el);
        best = lanes[0];
        for (k = 1; k < 4; ++k)
        {
            if (lanes[k] > best)
            {
                best = lanes[k];
            }
        }
    }

    /* Scalar tail: elements that do not fill a whole vector. */
    for (i = nVec; i < total; ++i)
    {
        if (m[i] > best)
        {
            best = m[i];
        }
    }

    return best;
}

I also have an AVX version that does not work:

/*
 * Return the largest element of a dims x dims matrix stored contiguously in m.
 *
 * Whole 8-float vectors are reduced with AVX. The 8 lane maxima are then
 * folded into lane 0 with a cross-half permute followed by two in-half
 * shuffles. The remaining (dims*dims) % 8 elements are handled by a scalar
 * tail loop, so nothing past the end of the matrix is read.
 *
 * Returns 0.0f when dims == 0 (empty matrix). NaN inputs are not handled
 * specially.
 *
 * On GCC/Clang the target attribute lets this function use AVX even when the
 * translation unit is compiled without -mavx; the CPU must still support AVX.
 */
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx")))
#endif
float find_largest_element_in_matrix_AVX(float* m, unsigned const int dims)
{
    /* Widen before multiplying so large dims cannot wrap in unsigned int. */
    const size_t total = (size_t)dims * dims;
    /* Number of leading elements covered by whole 8-wide vectors. */
    const size_t nVec = total & ~(size_t)7;
    float best;
    size_t i;

    if (total == 0)
    {
        return 0.0f;
    }

    best = m[0];
    if (nVec > 0)
    {
        __m256 max_el = _mm256_loadu_ps(m);

        for (i = 8; i < nVec; i += 8)
        {
            max_el = _mm256_max_ps(max_el, _mm256_loadu_ps(m + i));
        }

        /* Lane j := max(lane j, lane j +/- 4): both 128-bit halves now agree. */
        max_el = _mm256_max_ps(max_el, _mm256_permute2f128_ps(max_el, max_el, 1));
        /* Within each half: combine lanes {0,1} with {2,3}. */
        max_el = _mm256_max_ps(max_el, _mm256_shuffle_ps(max_el, max_el, _MM_SHUFFLE(1, 0, 3, 2)));
        /* Within each half: combine neighbouring lanes; every lane holds the max. */
        max_el = _mm256_max_ps(max_el, _mm256_shuffle_ps(max_el, max_el, _MM_SHUFFLE(2, 3, 0, 1)));

        best = _mm_cvtss_f32(_mm256_castps256_ps128(max_el));
    }

    /* Scalar tail: elements that do not fill a whole vector. */
    for (i = nVec; i < total; ++i)
    {
        if (m[i] > best)
        {
            best = m[i];
        }
    }

    return best;
}

Could anyone help me find the index of the maximum element and get my AVX version working?


Here is a working SSE4.1 implementation that returns the maximum value and its corresponding index, along with a scalar reference implementation and a test harness:

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <time.h>
#include <smmintrin.h>  // SSE 4.1

/*
 * Scalar reference: scan the dims*dims elements of m and report the maximum.
 * Only a strictly greater value replaces the current best, so *maxIndex is the
 * position of the first occurrence of the maximum. Element m[0] is always read.
 */
float find_largest_element_in_matrix_ref(const float* m, int dims, int *maxIndex)
{
    const int count = dims * dims;
    int best = 0;   /* index of the best element seen so far */
    int k;

    for (k = 1; k < count; ++k)
    {
        if (m[k] > m[best])
        {
            best = k;
        }
    }

    *maxIndex = best;
    return m[best];
}

/*
 * Single-pass SIMD search for the maximum of a dims x dims matrix stored
 * contiguously in m. Returns the maximum and stores its flat index in
 * *maxIndex (row = index / dims, column = index % dims).
 *
 * Each SSE lane tracks the max value and index of the elements it sees. A
 * strict ">" compare makes every lane keep the FIRST occurrence of its max.
 * When lanes tie, the smallest index wins, so the result matches the scalar
 * reference (first occurrence overall). Elements that do not fill a whole
 * 4-wide vector are processed by a scalar tail loop; nothing past the end of
 * m is read.
 *
 * Only SSE2 intrinsics are used: the index select is done with and/andnot/or
 * instead of _mm_blendv_epi8. It still builds with -msse4.
 *
 * For dims <= 0, *maxIndex is set to -1 and 0.0f is returned.
 * Assumes dims * dims fits in an int. NaN inputs are not handled specially.
 */
float find_largest_element_in_matrix_SSE(const float* m, int dims, int *maxIndex)
{
    float maxVal;
    int n;
    int nVec;
    int i;

    if (dims <= 0)
    {
        *maxIndex = -1;
        return 0.0f;
    }

    n = dims * dims;     /* total number of elements */
    nVec = n & ~3;       /* elements covered by whole 4-wide vectors */

    maxVal = m[0];
    *maxIndex = 0;

    if (nVec > 0)
    {
        float aMaxVal[4];
        int32_t aMaxIndex[4];
        const __m128i vIndexInc = _mm_set1_epi32(4);
        __m128i vIndex = _mm_setr_epi32(0, 1, 2, 3);
        __m128i vMaxIndex = vIndex;
        __m128 vMaxVal = _mm_loadu_ps(m);

        for (i = 4; i < nVec; i += 4)
        {
            __m128 v = _mm_loadu_ps(&m[i]);
            /* All-ones in lanes where the new value is strictly larger. */
            __m128i vcmp = _mm_castps_si128(_mm_cmpgt_ps(v, vMaxVal));
            vIndex = _mm_add_epi32(vIndex, vIndexInc);
            vMaxVal = _mm_max_ps(vMaxVal, v);
            /* SSE2 select: vIndex where vcmp is set, otherwise keep vMaxIndex. */
            vMaxIndex = _mm_or_si128(_mm_and_si128(vcmp, vIndex),
                                     _mm_andnot_si128(vcmp, vMaxIndex));
        }
        _mm_storeu_ps(aMaxVal, vMaxVal);
        _mm_storeu_si128((__m128i *)aMaxIndex, vMaxIndex);

        maxVal = aMaxVal[0];
        *maxIndex = aMaxIndex[0];
        for (i = 1; i < 4; ++i)
        {
            /* On equal lane maxima prefer the smaller element index. */
            if (aMaxVal[i] > maxVal ||
                (aMaxVal[i] == maxVal && aMaxIndex[i] < *maxIndex))
            {
                maxVal = aMaxVal[i];
                *maxIndex = aMaxIndex[i];
            }
        }
    }

    /* Scalar tail. Tail indices exceed every vector index, so strict ">"
       preserves the first occurrence. */
    for (i = nVec; i < n; ++i)
    {
        if (m[i] > maxVal)
        {
            maxVal = m[i];
            *maxIndex = i;
        }
    }
    return maxVal;
}

/*
 * Test harness: fill a dims x dims matrix with pseudo-random values in [0, 1],
 * run the scalar reference and the SSE search, and report PASS/FAIL.
 * Exits with EXIT_SUCCESS on PASS and EXIT_FAILURE on FAIL or on allocation
 * failure, so the harness can be used from scripts.
 */
int main()
{
    const int dims = 1024;
    const size_t count = (size_t)dims * dims;
    /* Heap allocation: 1024 x 1024 floats is 4 MiB. As a local array it
       would be a VLA (dims is not a constant expression in C) and can
       overflow the stack, e.g. the 1 MiB default on Windows. */
    float *m = malloc(count * sizeof *m);
    float maxVal_ref, maxVal_SSE;
    int maxIndex_ref = -1, maxIndex_SSE = -1;
    int pass;
    size_t i;

    if (m == NULL)
    {
        fprintf(stderr, "out of memory\n");
        return EXIT_FAILURE;
    }

    srand((unsigned)time(NULL));

    for (i = 0; i < count; ++i)
    {
        m[i] = (float)rand() / RAND_MAX;
    }

    maxVal_ref = find_largest_element_in_matrix_ref(m, dims, &maxIndex_ref);
    maxVal_SSE = find_largest_element_in_matrix_SSE(m, dims, &maxIndex_SSE);

    pass = (maxVal_ref == maxVal_SSE && maxIndex_ref == maxIndex_SSE);
    if (pass)
    {
        printf("PASS: maxVal = %f, maxIndex = %d\n",
                      maxVal_ref, maxIndex_ref);
    }
    else
    {
        printf("FAIL: maxVal_ref = %f, maxVal_SSE = %f, maxIndex_ref = %d, maxIndex_SSE = %d\n",
                      maxVal_ref, maxVal_SSE, maxIndex_ref, maxIndex_SSE);
    }

    free(m);
    return pass ? EXIT_SUCCESS : EXIT_FAILURE;
}

Compile and run:

$ gcc -Wall -msse4 Yakovenko.c && ./a.out 
PASS: maxVal = 0.999999, maxIndex = 120409

The row and column indices then follow directly from the flat (row-major) index:

/* maxIndex is a row-major flat offset: maxIndex == rowIndex * dims + colIndex. */
int colIndex = maxIndex % dims;
int rowIndex = maxIndex / dims;

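From here it should be fairly straightforward to write an AVX2 implementation. AVX2, rather than plain AVX, is needed for the 256-bit integer operations used to track indices, such as `_mm256_add_epi32` and `_mm256_blendv_epi8`.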


One approach is to compute the maximum in a first pass and then find its index with a linear search in a second pass. Here is a sample SSE2 implementation:

#define anybit __builtin_ctz   //index of the lowest set bit (GCC/Clang); or a lookup table with 16 entries...

// Two-pass search for the maximum of a dims x dims matrix stored contiguously
// in m. Returns the maximum and stores the flat index of its FIRST occurrence
// in *maxIndex (row = index / dims, column = index % dims).
//
// Pass 1 computes the maximum (SIMD over whole 4-wide vectors, scalar over
// the (dims*dims) % 4 tail). Pass 2 searches for the first element equal to
// it. No element past the end of m is read.
//
// For dims <= 0, or if the search finds no element equal to the maximum
// (possible with NaN inputs), *maxIndex is set to -1. Assumes dims * dims fits
// in an int.
float find_largest_element_in_matrix_SSE(const float* m, int dims, int *maxIndex) {
    float maxVal;
    int n, nVec, i;

    if (dims <= 0) {
        *maxIndex = -1;
        return 0.0f;
    }
    n = dims * dims;
    nVec = n & ~3;   //elements covered by whole 4-wide vectors

    //first pass: calculate maximum
    maxVal = m[0];
    if (nVec > 0) {
        __m128 vMaxVal = _mm_loadu_ps(m);
        for (i = 4; i < nVec; i += 4)
            vMaxVal = _mm_max_ps(vMaxVal, _mm_loadu_ps(&m[i]));
        //in-register reduction: afterwards every lane holds the maximum
        vMaxVal = _mm_max_ps(vMaxVal, _mm_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(2, 3, 0, 1)));
        vMaxVal = _mm_max_ps(vMaxVal, _mm_shuffle_ps(vMaxVal, vMaxVal, _MM_SHUFFLE(1, 0, 3, 2)));
        maxVal = _mm_cvtss_f32(vMaxVal);
    }
    for (i = nVec; i < n; ++i) {   //scalar tail of the first pass
        if (m[i] > maxVal)
            maxVal = m[i];
    }

    //second pass: search for the first element equal to the maximum
    const __m128 vTarget = _mm_set1_ps(maxVal);
    for (i = 0; i < nVec; i += 4) {
        int mask = _mm_movemask_ps(_mm_cmpeq_ps(vTarget, _mm_loadu_ps(&m[i])));
        if (mask != 0) {
            *maxIndex = i + anybit((unsigned)mask);
            return maxVal;
        }
    }
    for (i = nVec; i < n; ++i) {   //scalar tail of the second pass
        if (m[i] == maxVal) {
            *maxIndex = i;
            return maxVal;
        }
    }

    //only reachable when no element compares equal (e.g. NaN maximum)
    *maxIndex = -1;
    return maxVal;
}

Note that the branch in the second loop should be almost perfectly predicted unless your input data is very small.

The solution suffers from several problems, notably:

  1. It may behave incorrectly in the presence of unusual floating-point values such as NaNs.

  2. If the matrix does not fit in the CPU cache, the code reads it from main memory twice, so it can be roughly twice as slow as a single-pass approach. For large matrices this can be solved by block-wise processing: find the maximum of one cache-sized block, then immediately search that block while it is still cached.

  3. In the first loop each iteration depends on the previous one (vMaxVal is both read and modified), so the loop is limited by the latency of _mm_max_ps. Unrolling it 2x or 4x with independent vMaxVal accumulator registers, combined at the end, would help. The second loop would also benefit from unrolling.

Porting to AVX should be straightforward, except for the in-register reduction, which must also combine the two 128-bit halves of the register:

/* Afterwards every lane of vMaxVal holds the overall maximum.
   _mm256_permute_ps(v, imm) picks the same elements as _mm256_shuffle_ps(v, v, imm):
   imm is applied independently inside each 128-bit half. */
vMaxVal = _mm256_max_ps(vMaxVal, _mm256_permute_ps(vMaxVal, _MM_SHUFFLE(2, 3, 0, 1)));  /* swap neighbours */
vMaxVal = _mm256_max_ps(vMaxVal, _mm256_permute_ps(vMaxVal, _MM_SHUFFLE(1, 0, 3, 2)));  /* swap pairs */
vMaxVal = _mm256_max_ps(vMaxVal, _mm256_permute2f128_ps(vMaxVal, vMaxVal, 1));          /* swap 128-bit halves */

Yet another approach, which tracks the index of each lane's maximum alongside the values:

/*
 * Find the maximum of an n x n matrix stored contiguously (row-major) in
 * matrix. Writes the value to *v and its row/column to *row / *column.
 * Ties resolve to the first occurrence.
 *
 * Each SSE lane tracks the int32 index of the largest value it has seen
 * (earlier index kept on ties). The lane winners are then reduced by value,
 * with the smaller index winning a tie. The final tail of fewer than 4
 * elements is scanned with scalar code.
 *
 * Changes from the float-index version:
 *   - unaligned loads: the caller's matrix need not be 16-byte aligned;
 *   - integer indices, which stay exact beyond 2^24 elements;
 *   - a lane with a strictly larger value now wins even if its index is larger;
 *   - no reads past the end of the matrix when n*n is not a multiple of 4.
 *
 * For n == 0 writes *row = *column = -1 and *v = 0.0f.
 * Assumes n * n fits in an int32_t.
 */
void find_largest_element_in_matrix_SSE(float * matrix, size_t n, int * row, int * column, float * v){
    const size_t total = n * n;
    const size_t nVec = total & ~(size_t)3;   /* elements covered by whole vectors */
    size_t c = 0;                              /* flat index of the best element */
    size_t i;

    if (total == 0) {
        *row = -1;
        *column = -1;
        *v = 0.0f;
        return;
    }

    if (nVec > 0) {
        const __m128i update = _mm_set1_epi32(4);
        __m128i indices = _mm_setr_epi32(0, 1, 2, 3);
        __m128i max_indices = indices;
        __m128 vmax = _mm_loadu_ps(matrix);
        int32_t max_ind[4];
        int k;

        for (i = 4; i < nVec; i += 4) {
            __m128 pm2 = _mm_loadu_ps(&matrix[i]);
            /* All-ones where the running max is >= the new value: keep old index. */
            __m128i keep = _mm_castps_si128(_mm_cmpge_ps(vmax, pm2));
            indices = _mm_add_epi32(indices, update);
            vmax = _mm_max_ps(vmax, pm2);
            max_indices = _mm_or_si128(_mm_and_si128(keep, max_indices),
                                       _mm_andnot_si128(keep, indices));
        }
        _mm_storeu_si128((__m128i *)max_ind, max_indices);

        c = (size_t)max_ind[0];
        for (k = 1; k < 4; k++) {
            const size_t cand = (size_t)max_ind[k];
            if (matrix[cand] > matrix[c] ||
                (matrix[cand] == matrix[c] && cand < c)) {
                c = cand;
            }
        }
    }

    /* Scalar tail; strict ">" keeps the earlier index on ties. */
    for (i = nVec; i < total; i++) {
        if (matrix[i] > matrix[c]) {
            c = i;
        }
    }

    *v = matrix[c];
    *row = (int)(c / n);
    *column = (int)(c % n);
}

/*
 * AVX search for the maximum of an n x n matrix stored contiguously
 * (row-major) in matrix. Writes the value to *v and its row/column to
 * *row / *column. Ties resolve to the first occurrence.
 *
 * AVX (without AVX2) has no 256-bit integer arithmetic, so positions are
 * tracked in float registers. Each lane stores the number of the 8-element
 * block in which it last saw a strictly larger value. That lane's flat index
 * is then block * 8 + lane. Block numbers are exact in float up to 2^24,
 * which allows up to 2^27 elements, 8x more than storing element indices as
 * floats.
 *
 * The lane winners are reduced by value, with the smaller index winning a
 * tie. Elements after the last whole 8-wide vector are scanned with scalar
 * code. Loads are unaligned, so matrix need not be 32-byte aligned.
 *
 * For n == 0 writes *row = *column = -1 and *v = 0.0f.
 * On GCC/Clang the target attribute enables AVX for this function without
 * -mavx; the CPU must still support AVX.
 */
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx")))
#endif
void find_largest_element_in_matrix_AVX(float * matrix, size_t n, int * row,  int * column, float * v){
    const size_t total = n * n;
    const size_t nVec = total & ~(size_t)7;   /* elements covered by whole vectors */
    size_t c = 0;                              /* flat index of the best element */
    size_t i;

    if (total == 0) {
        *row = -1;
        *column = -1;
        *v = 0.0f;
        return;
    }

    if (nVec > 0) {
        const __m256 one = _mm256_set1_ps(1.0f);
        __m256 block = _mm256_setzero_ps();      /* number of the current 8-element block */
        __m256 max_block = block;                /* per lane: block holding the lane max */
        __m256 vmax = _mm256_loadu_ps(matrix);
        float blk[8];
        int k;

        for (i = 8; i < nVec; i += 8) {
            __m256 cur = _mm256_loadu_ps(&matrix[i]);
            /* Take the new block only on a strictly larger value (keeps the first occurrence). */
            __m256 gt = _mm256_cmp_ps(cur, vmax, _CMP_GT_OQ);
            block = _mm256_add_ps(block, one);
            vmax = _mm256_max_ps(vmax, cur);
            max_block = _mm256_blendv_ps(max_block, block, gt);
        }
        _mm256_storeu_ps(blk, max_block);

        c = (size_t)blk[0] * 8;
        for (k = 1; k < 8; k++) {
            const size_t cand = (size_t)blk[k] * 8 + (size_t)k;
            if (matrix[cand] > matrix[c] ||
                (matrix[cand] == matrix[c] && cand < c)) {
                c = cand;
            }
        }
    }

    /* Scalar tail; strict ">" keeps the earlier index on ties. */
    for (i = nVec; i < total; i++) {
        if (matrix[i] > matrix[c]) {
            c = i;
        }
    }

    *v = matrix[c];
    *row = (int)(c / n);
    *column = (int)(c % n);
}