Can I use the AVX FMA units to do bit-exact 52 bit integer multiplications?
Solution 1:
Yes, it's possible. But as of AVX2, it's unlikely to be better than the scalar approaches with MULX/ADCX/ADOX.
There's a virtually unlimited number of variations of this approach for different input/output domains. I'll only cover 3 of them, but they are easy to generalize once you know how they work.
Disclaimers:
- All solutions here assume the rounding mode is round-to-nearest-even (the default).
- Use of fast-math optimization flags is not recommended, as these solutions rely on strict IEEE semantics.
Signed doubles in the range: [-2^51, 2^51]
#include <immintrin.h>
// A*B = L + H*2^52
// Input: A and B are in the range [-2^51, 2^51]
// Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256d& L, __m256d& H, __m256d A, __m256d B){
    const __m256d ROUND = _mm256_set1_pd(30423614405477505635920876929024.); // 3 * 2^103
    const __m256d SCALE = _mm256_set1_pd(1. / 4503599627370496); // 1 / 2^52
    // Multiply and add normalization constant. This forces the multiply
    // to be rounded to the correct number of bits.
    H = _mm256_fmadd_pd(A, B, ROUND);
    // Undo the normalization.
    H = _mm256_sub_pd(H, ROUND);
    // Recover the bottom half of the product.
    L = _mm256_fmsub_pd(A, B, H);
    // Correct the scaling of H.
    H = _mm256_mul_pd(H, SCALE);
}
This is the simplest one and the only one which is competitive with the scalar approaches. The final scaling is optional depending on what you want to do with the outputs. So this can be considered only 3 instructions. But it's also the least useful since both the inputs and outputs are floating-point values.
It is absolutely critical that both FMAs stay fused, and this is where fast-math optimizations can break things. If the first FMA is broken up, then L is no longer guaranteed to be in the range [-2^51, 2^51]. If the second FMA is broken up, L will be completely wrong.
Signed integers in the range: [-2^51, 2^51]
// A*B = L + H*2^52
// Input: A and B are in the range [-2^51, 2^51]
// Output: L and H are in the range [-2^51, 2^51]
void mul52_signed(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(6755399441055744); // 3*2^51
    const __m256d CONVERT_D = _mm256_set1_pd(1.5);
    __m256d l, h, a, b;
    // Convert to double: a = A, b = B * 2^-52
    A = _mm256_add_epi64(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_add_epi64(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);
    // Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_sub_epi64(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));
    // Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);
    // Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);
    // Convert L to int64
    l = _mm256_add_pd(l, CONVERT_D);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_D));
}
Building off of the first example, we combine it with a generalized version of the fast double <-> int64 conversion trick.
This one is more useful since you're working with integers. But even with the fast conversion trick, most of the time will be spent doing conversions. Fortunately, you can eliminate some of the input conversions if you are multiplying by the same operand multiple times.
Unsigned integers in the range: [0, 2^52)
// A*B = L + H*2^52
// Input: A and B are in the range [0, 2^52)
// Output: L and H are in the range [0, 2^52)
void mul52_unsigned(__m256i& L, __m256i& H, __m256i A, __m256i B){
    const __m256d CONVERT_U = _mm256_set1_pd(4503599627370496); // 2^52
    const __m256d CONVERT_D = _mm256_set1_pd(1);
    const __m256d CONVERT_S = _mm256_set1_pd(1.5);
    __m256d l, h, a, b;
    // Convert to double: a = A, b = B * 2^-52
    A = _mm256_or_si256(A, _mm256_castpd_si256(CONVERT_U));
    B = _mm256_or_si256(B, _mm256_castpd_si256(CONVERT_D));
    a = _mm256_sub_pd(_mm256_castsi256_pd(A), CONVERT_U);
    b = _mm256_sub_pd(_mm256_castsi256_pd(B), CONVERT_D);
    // Get top half. Convert H to int64.
    h = _mm256_fmadd_pd(a, b, CONVERT_U);
    H = _mm256_xor_si256(_mm256_castpd_si256(h), _mm256_castpd_si256(CONVERT_U));
    // Undo the normalization.
    h = _mm256_sub_pd(h, CONVERT_U);
    // Recover bottom half.
    l = _mm256_fmsub_pd(a, b, h);
    // Convert L to int64
    l = _mm256_add_pd(l, CONVERT_S);
    L = _mm256_sub_epi64(_mm256_castpd_si256(l), _mm256_castpd_si256(CONVERT_S));
    // Make correction: if L went negative, borrow 1 from H and wrap L into [0, 2^52)
    H = _mm256_sub_epi64(H, _mm256_srli_epi64(L, 63));
    L = _mm256_and_si256(L, _mm256_set1_epi64x(0x000fffffffffffff));
}
Finally we get the answer to the original question. This builds off of the signed integer solution by adjusting the conversions and adding a correction step.
But at this point, we're at 13 instructions, half of which are high-latency instructions, not counting the numerous FP <-> int bypass delays. So it's unlikely this will be winning any benchmarks. By comparison, a 64 x 64 -> 128-bit SIMD multiply can be done in 16 instructions (14 if you pre-process the inputs).
The correction step can be omitted if the rounding mode is round-down or round-to-zero. The only instruction where this matters is h = _mm256_fmadd_pd(a, b, CONVERT_U), so on AVX512 you can override the rounding for that instruction alone and leave the rounding mode untouched.
Final Thoughts:
It's worth noting that the 2^52 range of operation can be reduced by adjusting the magic constants. This may be useful for the first solution (the floating-point one), since it gives you extra mantissa bits to use for accumulation in floating-point. This lets you bypass the need to constantly convert back and forth between int64 and double as in the last 2 solutions.
While the 3 examples here are unlikely to be better than scalar methods, AVX512 will almost certainly tip the balance. Knights Landing in particular has poor throughput for ADCX and ADOX.
And of course all of this is moot when AVX512-IFMA comes out. That reduces a full 52 x 52 -> 104-bit product to 2 instructions and gives the accumulation for free.
Solution 2:
One way to do multi-word integer arithmetic is with double-double arithmetic. Let's start with some double-double multiplication code:
#include <math.h>
typedef struct {
    double hi;
    double lo;
} doubledouble;
static doubledouble quick_two_sum(double a, double b) {
    double s = a + b;
    double e = b - (s - a);
    return (doubledouble){s, e};
}
static doubledouble two_prod(double a, double b) {
    double p = a*b;
    double e = fma(a, b, -p); // exact rounding error of p
    return (doubledouble){p, e};
}
doubledouble df64_mul(doubledouble a, doubledouble b) {
    doubledouble p = two_prod(a.hi, b.hi);
    p.lo += a.hi*b.lo;
    p.lo += a.lo*b.hi;
    return quick_two_sum(p.hi, p.lo);
}
The function two_prod can do integer 53b x 53b -> 106b in two instructions. The function df64_mul can do integer 106b x 106b -> 106b.
Let's compare this to integer 128bx128b -> 128b with integer hardware.
__int128 mul128(__int128 a, __int128 b) {
    return a*b;
}
The assembly for mul128:
imul rsi, rdx
mov rax, rdi
imul rcx, rdi
mul rdx
add rcx, rsi
add rdx, rcx
The assembly for df64_mul (compiled with gcc -O3 -S i128.c -masm=intel -mfma -ffp-contract=off):
vmulsd xmm4, xmm0, xmm2
vmulsd xmm3, xmm0, xmm3
vmulsd xmm1, xmm2, xmm1
vfmsub132sd xmm0, xmm4, xmm2
vaddsd xmm3, xmm3, xmm0
vaddsd xmm1, xmm3, xmm1
vaddsd xmm0, xmm1, xmm4
vsubsd xmm4, xmm0, xmm4
vsubsd xmm1, xmm1, xmm4
mul128 does three scalar multiplications and two scalar additions/subtractions, whereas df64_mul does 3 SIMD multiplications, 1 SIMD FMA, and 5 SIMD additions/subtractions. I have not profiled these methods, but it does not seem unreasonable to me that df64_mul could outperform mul128 using 4 doubles per AVX register (change sd to pd and xmm to ymm).
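A 4-wide version could look like the sketch below. It performs the same operations as the scalar df64_mul, with sd replaced by pd and xmm by ymm through AVX intrinsics. It assumes GCC or Clang on x86-64. The target attribute compiles just this function for AVX+FMA without global -mavx -mfma, so callers should check the CPU first; have_avx_fma is a hypothetical helper for that. As with the assembly above, building with -ffp-contract=off keeps the compiler from fusing the cross-term multiply-adds. Plain double arrays at the interface mean no __m256d value crosses into non-AVX code.

```c
#include <immintrin.h>

/* r = a * b for four independent double-double values (structure of arrays). */
__attribute__((target("avx,fma")))
static void df64_mul4(const double a_hi[4], const double a_lo[4],
                      const double b_hi[4], const double b_lo[4],
                      double r_hi[4], double r_lo[4]) {
    __m256d ahi = _mm256_loadu_pd(a_hi), alo = _mm256_loadu_pd(a_lo);
    __m256d bhi = _mm256_loadu_pd(b_hi), blo = _mm256_loadu_pd(b_lo);
    /* two_prod(a.hi, b.hi): p = fl(ahi*bhi), e = its exact error */
    __m256d p = _mm256_mul_pd(ahi, bhi);
    __m256d e = _mm256_fmsub_pd(ahi, bhi, p);
    /* p.lo += a.hi*b.lo; p.lo += a.lo*b.hi; */
    e = _mm256_add_pd(e, _mm256_mul_pd(ahi, blo));
    e = _mm256_add_pd(e, _mm256_mul_pd(alo, bhi));
    /* quick_two_sum(p, e) */
    __m256d s = _mm256_add_pd(p, e);
    __m256d err = _mm256_sub_pd(e, _mm256_sub_pd(s, p));
    _mm256_storeu_pd(r_hi, s);
    _mm256_storeu_pd(r_lo, err);
}

/* Hypothetical runtime check before calling df64_mul4. */
static int have_avx_fma(void) {
    __builtin_cpu_init();
    return __builtin_cpu_supports("avx") && __builtin_cpu_supports("fma");
}
```

Because each lane is an independent double-double, a loop over arrays of hi/lo parts can feed this four products at a time and fall back to the scalar df64_mul when have_avx_fma() returns 0.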
It's tempting to say that the problem is switching back to the integer domain. But why is this necessary? You can do everything in the floating-point domain. Let's look at some examples. I find it easier to unit test with float than with double.
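#include <stdio.h>
#include <stdint.h>
#include <inttypes.h>
#include <math.h>
typedef struct { float hi; float lo; } doublefloat;
doublefloat two_prod(float a, float b) {
    float p = a*b;
    float e = fmaf(a, b, -p); // exact rounding error of p
    return (doublefloat){p, e};
}
int main(void) {
    //3202129*4807935=15395628093615
    doublefloat x = two_prod(3202129, 4807935);
    int64_t hi = x.hi, lo = x.lo, s = hi + lo;
    //p = 1.53956280e+13, e = 1.02575000e+05
    //hi = 15395627991040, lo = 102575, s = 15395628093615
    printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", hi, lo, s);
    //1450779*1501672=2178594202488
    doublefloat y = two_prod(1450779, 1501672);
    hi = y.hi; lo = y.lo; s = hi + lo;
    //p = 2.17859424e+12, e = -4.00720000e+04
    //hi = 2178594242560, lo = -40072, s = 2178594202488
    printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", hi, lo, s);
    return 0;
}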
So we end up with different ranges, and in the second case the error (e) is even negative, but the sum is still correct. We could even add the two doublefloat values x and y together (once we know how to do double-double addition; see the code at the end) and get 15395628093615+2178594202488. There is no need to normalize the results.
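Here is a small check of that claim. It assumes x86-64 style float arithmetic (FLT_EVAL_METHOD == 0, round-to-nearest-even). The sketch copies two_sum and add22 from the full listing at the end and adds the x and y values from above without normalizing them first. The helper doublefloat_to_int64 is hypothetical and assumes both parts hold integers.

```c
#include <assert.h>
#include <math.h>
#include <stdint.h>

typedef struct {
    float hi;
    float lo;
} doublefloat;

/* Error-free sum: s + e == a + b exactly. */
static doublefloat two_sum(float a, float b) {
    float s = a + b;
    float v = s - a;
    float e = (a - (s - v)) + (b - v);
    return (doublefloat){s, e};
}

/* Double-float addition (11 adds); the inputs need not be normalized. */
static doublefloat add22(doublefloat a, doublefloat b) {
    float r = a.hi + b.hi;
    float s = fabsf(a.hi) > fabsf(b.hi) ?
        (((a.hi - r) + b.hi) + b.lo) + a.lo :
        (((b.hi - r) + a.hi) + a.lo) + b.lo;
    return two_sum(r, s);
}

/* Hypothetical helper: read back an integer-valued doublefloat exactly. */
static int64_t doublefloat_to_int64(doublefloat a) {
    assert((float)(int64_t)a.hi == a.hi && (float)(int64_t)a.lo == a.lo);
    return (int64_t)a.hi + (int64_t)a.lo;
}
```

With x = {15395627991040, 102575} and y = {2178594242560, -40072}, add22(x, y) should come out as {17574221840384, 455719}. Those parts sum to 17574222296103, which is 15395628093615 + 2178594202488.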
But addition brings up the main problem with double-double arithmetic: addition/subtraction is slow. For example, 128b+128b -> 128b needs at least 11 floating-point additions, whereas with integers it only needs two (add and adc).
So if an algorithm is heavy on multiplication but light on addition then doing multi-word integer operations with double-double could win.
As a side note, the C language is flexible enough to allow for an implementation where integers are implemented entirely through floating-point hardware: int could be 24 bits (from single-precision floating point), long could be 53 bits (from double-precision floating point), and long long could be 106 bits (from double-double). Before C23, C did not even require two's complement, so integers could use sign-magnitude for negative numbers, as is usual with floating point.
Here is working C code with double-double multiplication and addition (I have not implemented division or other operations such as sqrt, but there are papers showing how to do this) in case somebody wants to play with it. It would be interesting to see if this could be optimized for integers.
//if compiling with -mfma you must also use -ffp-contract=off
//float-float is easier to debug. If you want double-double replace
//all float words with double and fmaf with fma
#include <stdio.h>
#include <math.h>
#include <inttypes.h>
#include <x86intrin.h>
#include <stdlib.h>
//#include <float.h>
typedef struct {
    float hi;
    float lo;
} doublefloat;
typedef union {
    float f;
    int i;
    struct {
        unsigned mantisa : 23;
        unsigned exponent: 8;
        unsigned sign: 1;
    };
} float_cast;
void print_float(float_cast a) {
    printf("%.8e, 0x%x, mantissa 0x%x, exponent 0x%x, exponent-127 %d, sign %u\n", a.f, a.i, a.mantisa, a.exponent, a.exponent-127, a.sign);
}
void print_doublefloat(doublefloat a) {
    float_cast hi = {a.hi};
    float_cast lo = {a.lo};
    printf("hi: "); print_float(hi);
    printf("lo: "); print_float(lo);
}
doublefloat quick_two_sum(float a, float b) {
    float s = a + b;
    float e = b - (s - a);
    return (doublefloat){s, e};
    // 3 add
}
doublefloat two_sum(float a, float b) {
    float s = a + b;
    float v = s - a;
    float e = (a - (s - v)) + (b - v);
    return (doublefloat){s, e};
    // 6 add
}
doublefloat df64_add(doublefloat a, doublefloat b) {
    doublefloat s, t;
    s = two_sum(a.hi, b.hi);
    t = two_sum(a.lo, b.lo);
    s.lo += t.hi;
    s = quick_two_sum(s.hi, s.lo);
    s.lo += t.lo;
    s = quick_two_sum(s.hi, s.lo);
    return s;
    // 2*two_sum, 2 add, 2*quick_two_sum = 2*6 + 2 + 2*3 = 20 add
}
doublefloat split(float a) {
//#define SPLITTER ((1<<27) + 1)
#define SPLITTER ((1<<12) + 1)
    float t = (SPLITTER)*a;
    float hi = t - (t - a);
    float lo = a - hi;
    return (doublefloat){hi, lo};
    // 1 mul, 3 add
}
doublefloat split_sse(float a) {
    __m128 k = _mm_set1_ps(4097.0f);
    __m128 a4 = _mm_set1_ps(a);
    __m128 t = _mm_mul_ps(k,a4);
    __m128 hi4 = _mm_sub_ps(t,_mm_sub_ps(t, a4));
    __m128 lo4 = _mm_sub_ps(a4, hi4);
    float tmp[4];
    _mm_storeu_ps(tmp, hi4);
    float hi = tmp[0];
    _mm_storeu_ps(tmp, lo4);
    float lo = tmp[0];
    return (doublefloat){hi,lo};
}
float mult_sub(float a, float b, float c) {
    doublefloat as = split(a), bs = split(b);
    //print_doublefloat(as);
    //print_doublefloat(bs);
    return ((as.hi*bs.hi - c) + as.hi*bs.lo + as.lo*bs.hi) + as.lo*bs.lo;
    // 4 mul, 4 add, 2 split = 6 mul, 10 add
}
doublefloat two_prod(float a, float b) {
    float p = a*b;
    float e = mult_sub(a, b, p);
    return (doublefloat){p, e};
    // 1 mul, one mult_sub
    // 7 mul, 10 add
}
float mult_sub2(float a, float b, float c) {
    doublefloat as = split(a);
    return ((as.hi*as.hi -c ) + 2*as.hi*as.lo) + as.lo*as.lo;
}
doublefloat two_sqr(float a) {
    float p = a*a;
    float e = mult_sub2(a, a, p);
    return (doublefloat){p, e};
}
doublefloat df64_mul(doublefloat a, doublefloat b) {
    doublefloat p = two_prod(a.hi, b.hi);
    p.lo += a.hi*b.lo;
    p.lo += a.lo*b.hi;
    return quick_two_sum(p.hi, p.lo);
    //two_prod, 2 add, 2mul, 1 quick_two_sum = 9 mul, 15 add
    //or 1 mul, 1 fma, 2add 2mul, 1 quick_two_sum = 3 mul, 1 fma, 5 add
}
doublefloat df64_sqr(doublefloat a) {
    doublefloat p = two_sqr(a.hi);
    p.lo += 2*a.hi*a.lo;
    return quick_two_sum(p.hi, p.lo);
}
int float2int(float a) {
    int M = 0xc00000; //1100 0000 0000 0000 0000 0000
    a += M;
    float_cast x;
    x.f = a;
    return x.i - 0x4b400000;
}
doublefloat add22(doublefloat a, doublefloat b) {
    float r = a.hi + b.hi;
    float s = fabsf(a.hi) > fabsf(b.hi) ?
        (((a.hi - r) + b.hi) + b.lo ) + a.lo :
        (((b.hi - r) + a.hi) + a.lo ) + b.lo;
    return two_sum(r, s);
    //11 add
}
int main(void) {
    //print_float((float_cast){1.0f});
    //print_float((float_cast){-2.0f});
    //print_float((float_cast){0.0f});
    //print_float((float_cast){3.14159f});
    //print_float((float_cast){1.5f});
    //print_float((float_cast){3.0f});
    //print_float((float_cast){7.0f});
    //print_float((float_cast){15.0f});
    //print_float((float_cast){31.0f});
    //uint64_t t = 0xffffff;
    //print_float((float_cast){1.0f*t});
    //printf("%" PRId64 " %" PRIx64 "\n", t*t,t*t);
    /*
    float_cast t1;
    t1.mantisa = 0x7fffff;
    t1.exponent = 0xfe;
    t1.sign = 0;
    print_float(t1);
    */
    //doublefloat z = two_prod(1.0f*t, 1.0f*t);
    //print_doublefloat(z);
    //double z2 = (double)z.hi + (double)z.lo;
    //printf("%.16e\n", z2);
    doublefloat s = {0};
    int64_t si = 0;
    for(int i=0; i<100000; i++) {
        int ai = rand()%0x800, bi = rand()%0x800000;
        float a = ai, b = bi;
        doublefloat z = two_prod(a,b);
        int64_t zi = (int64_t)ai*bi;
        //print_doublefloat(z);
        //s = df64_add(s,z);
        s = add22(s,z);
        si += zi;
        print_doublefloat(z);
        printf("%d %d ", ai,bi);
        int64_t h = z.hi;
        int64_t l = z.lo;
        int64_t t = h+l;
        //if(t != zi) printf("%" PRId64 " %" PRId64 "\n", h, l);
        printf("%" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 "\n", zi, h, l, h+l);
        h = s.hi;
        l = s.lo;
        t = h + l;
        //if(si != t) printf("%" PRId64 " %" PRId64 "\n", h, l);
        if(si > (1LL<<48)) {
            printf("overflow after %d iterations\n", i); break;
        }
    }
    print_doublefloat(s);
    printf("%" PRId64 "\n", si);
    int64_t x = s.hi;
    int64_t y = s.lo;
    int64_t z = x+y;
    //int hi = float2int(s.hi);
    printf("%" PRId64 " %" PRId64 " %" PRId64 "\n", z,x,y);
}
Solution 3:
Well, you certainly can do FP-lane operations on things that are integers, and they will be exact as long as every operand and result fits in the significand (24 bits for float, 53 bits for double). While there are SSE instructions that do not guarantee proper IEEE-754 precision and rounding (the RCPPS and RSQRTPS approximations, for example), without exception they are operations that don't produce integer results, so they're not the ones you're looking at anyway. Bottom line: within that limit, addition, subtraction, and multiplication will always be exact in the integer domain, even if you're doing them on packed floats.
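The "fits in the significand" condition is easy to probe in C. Here is a minimal sketch, assuming IEEE-754 float/double with FLT_EVAL_METHOD == 0 (as on x86-64 with SSE). The helper names are hypothetical.

```c
#include <stdint.h>

/* An integer survives a round trip through float only if it fits in 24 significant bits.
 * Assumes |x| < 2^62 so the conversion back to int64_t cannot overflow. */
static int exact_in_float(int64_t x) { return (int64_t)(float)x == x; }

/* Doubles carry 53 significant bits (same |x| < 2^62 assumption). */
static int exact_in_double(int64_t x) { return (int64_t)(double)x == x; }

/* A float multiply of integers is exact only while the product fits in 24 bits. */
static float mul_float(float a, float b) { return a * b; }
```

For example, 4095.0f * 4097.0f = 16777215 is exact because it is below 2^24. In contrast, 4097.0f * 4097.0f should round 16785409 to 16785408, because the true product needs 25 significant bits.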
As for quad-precision floats (>52 bit mantissa), no, those aren't supported, and likely won't be in the foreseeable future. Just not much call for them. They showed up in a few SPARC-era workstation architectures, but honestly they were just a bandage over developers' incomplete understanding of how to write numerically stable algorithms, and over time they faded out.
Wide-integer operations turn out to be a really bad fit for SSE. I really tried to leverage it recently when I was implementing a big-integer library, and honestly it did me no good. x86 was designed for multiword arithmetic; you can see it in operations such as ADC (which produces and consumes a carry bit) and IDIV (which allows the dividend to be twice as wide as the divisor, as long as the quotient is no wider than the divisor, a constraint that makes it useless for anything but multiword division). But multiword arithmetic is by nature sequential, and SSE is by nature parallel. If you're lucky enough that your numbers have just enough bits to fit into an FP mantissa, congratulations. But if you have big integers in general, SSE is probably not going to be your friend.
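The sequential nature is visible even in portable C. In this sketch of little-endian multiword addition (multiword_add is a hypothetical name), each limb's carry-out feeds the next limb. That is the dependency ADC chains in hardware, and it is one that independent SIMD lanes cannot share.

```c
#include <stddef.h>
#include <stdint.h>

/* r = a + b over n 64-bit limbs (least significant first); returns the final carry.
 * Each iteration needs the previous carry, so the loop is inherently serial. */
static uint64_t multiword_add(uint64_t *r, const uint64_t *a, const uint64_t *b, size_t n) {
    uint64_t carry = 0;
    for (size_t i = 0; i < n; i++) {
        uint64_t s = a[i] + carry;
        uint64_t c1 = s < carry;   /* carry out of a[i] + carry */
        uint64_t t = s + b[i];
        uint64_t c2 = t < s;       /* carry out of s + b[i] */
        r[i] = t;
        carry = c1 | c2;           /* at most one of c1, c2 can be set */
    }
    return carry;
}
```

Adding 1 to {UINT64_MAX, UINT64_MAX, 0} ripples the carry through every limb and gives {0, 0, 1}. No limb's result is known until the one below it is finished.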