Efficient implementation of log2(__m256d) in AVX2
The usual strategy is based on the identity log(a*b) = log(a) + log(b), or in this case log2(2^exponent * mantissa) = log2(2^exponent) + log2(mantissa). Or, simplifying, exponent + log2(mantissa). The mantissa has a very limited range, 1.0 to 2.0, so a polynomial for log2(mantissa) only has to fit over that very limited range. (Or equivalently, mantissa = 0.5 to 1.0, and change the exponent bias-correction constant by 1.)
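Here is a scalar illustration of that decomposition for double. It is a sketch, not code from any library. The exponent and mantissa can be pulled apart with plain bit operations, which map directly onto AVX2 shifts and AND/OR (_mm256_srli_epi64, _mm256_and_si256, _mm256_or_si256). The helper names are hypothetical, and the sketch assumes a positive, normal, finite input. std::log2 stands in for the polynomial on [1.0, 2.0).

One AVX2-specific wrinkle: there is no packed int64-to-double conversion before AVX-512DQ. Vector code therefore usually narrows the exponents to 32-bit first and then uses _mm256_cvtepi32_pd.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper type: x == 2^exponent * mantissa, mantissa in [1.0, 2.0).
struct ExpMant {
    int exponent;
    double mantissa;
};

// Assumes x is positive, normal and finite.
ExpMant split_exponent_mantissa(double x) {
    std::uint64_t bits;
    std::memcpy(&bits, &x, sizeof bits);  // type-pun without aliasing UB
    // The biased exponent lives in bits 52..62; remove the IEEE bias of 1023.
    int exponent = static_cast<int>((bits >> 52) & 0x7FF) - 1023;
    // Keep the 52 mantissa bits and OR in the exponent field of 1.0.
    std::uint64_t mbits = (bits & ((1ULL << 52) - 1)) | (1023ULL << 52);
    double mantissa;
    std::memcpy(&mantissa, &mbits, sizeof mantissa);
    return {exponent, mantissa};
}

// log2(x) = exponent + log2(mantissa); std::log2 stands in for the polynomial.
double log2_via_split(double x) {
    ExpMant em = split_exponent_mantissa(x);
    return em.exponent + std::log2(em.mantissa);
}
```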
A Taylor series expansion is a good starting point for the coefficients, but you usually want to minimize the max absolute error (or relative error) over that specific range. Taylor series coefficients are exact at the expansion point and get worse toward the ends of the range, so they likely leave an outlier in one direction rather than having the max positive error nearly match the max negative error. So you can do what's called a minimax fit of the coefficients.
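A quick scalar experiment shows the lopsidedness. It assumes the target is ln(1+u) on [0, 1) and uses the Taylor series around u = 0; the helper names are hypothetical.

The error is essentially zero near the expansion point and grows toward u = 1. For a given number of terms, it has the same sign across the whole range, which is exactly what a minimax fit would rebalance. Adding terms shrinks the worst case, which is the usual accuracy-versus-speed knob.

```cpp
#include <cassert>
#include <cmath>

// Truncated Taylor series of ln(1+u) around u = 0, keeping n terms.
double taylor_ln1p(double u, int n) {
    double sum = 0.0;
    double power = u;
    for (int k = 1; k <= n; ++k) {
        sum += ((k % 2) ? power : -power) / k;  // +u, -u^2/2, +u^3/3, ...
        power *= u;
    }
    return sum;
}

// Signed error (approximation minus exact value) at u.
double taylor_error(double u, int n) {
    return taylor_ln1p(u, n) - std::log1p(u);
}

// Largest |error| over a uniform grid on [0, 1).
double max_abs_error(int n) {
    double worst = 0.0;
    for (int i = 0; i < 1000; ++i) {
        worst = std::fmax(worst, std::fabs(taylor_error(i / 1000.0, n)));
    }
    return worst;
}
```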
If it's important that your function evaluates log2(1.0) to exactly 0.0, you can arrange for that to happen by actually using mantissa-1.0 as your polynomial's input, with no constant coefficient: 0.0 ^ n = 0.0. This greatly improves the relative error for inputs near 1.0 as well, even though the absolute error there was already small.
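Here is a minimal sketch of the no-constant-term form. Purely for illustration, it assumes Taylor coefficients of ln(1+u) scaled by 1/ln(2); a real implementation would refit them, e.g. with minimax. The function names are hypothetical.

Every term carries at least one factor of u = mantissa - 1.0, so an input of exactly 1.0 gives exactly 0.0, and inputs near 1.0 keep a small relative error. The subtraction mantissa - 1.0 is itself exact for mantissas in [1.0, 2.0) (Sterbenz lemma).

```cpp
#include <cassert>
#include <cmath>

// log2(1+u) as u * p(u): no constant term, Horner form with FMAs.
// ASSUMPTION: Taylor coefficients of ln(1+u), illustration only.
double log2_1p_poly(double u) {
    const double inv_ln2 = 1.4426950408889634;  // 1/ln(2)
    double p = 1.0 / 5;
    p = std::fma(p, u, -1.0 / 4);
    p = std::fma(p, u, 1.0 / 3);
    p = std::fma(p, u, -1.0 / 2);
    p = std::fma(p, u, 1.0);
    // u*p == u - u^2/2 + u^3/3 - u^4/4 + u^5/5, which is exactly 0 when u == 0.
    return u * p * inv_ln2;
}

// mantissa in [1.0, 2.0); the subtraction is exact.
double log2_approx(double mantissa) {
    return log2_1p_poly(mantissa - 1.0);
}
```

This five-term Taylor polynomial is only accurate near u = 0. Over the full [0, 1) range it needs refitting or a further range reduction.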
How accurate do you need it to be, and over what range of inputs? As usual there's a tradeoff between accuracy and speed, but fortunately it's pretty easy to move along that scale by e.g. adding one more polynomial term (and re-fitting the coefficients), or by dropping some rounding-error avoidance.
Agner Fog's VCL implementation of log_d() aims for very high accuracy, using tricks to avoid rounding error, such as avoiding adding a small number to a large one when possible. This obscures the basic design somewhat.
For a faster, more approximate float log(), see the polynomial implementation on http://jrfonseca.blogspot.ca/2008/09/fast-sse2-pow-tables-or-polynomials.html. It leaves out a LOT of the extra precision-gaining tricks that VCL uses, so it's easier to understand. It uses a polynomial approximation for the mantissa over the 1.0 to 2.0 range. (That's the real trick to log() implementations: you only need a polynomial that works over a small range.)
It already just does log2 instead of log, unlike VCL's, where the log-base-e is baked into the constants and how it uses them. Reading it is probably a good starting point for understanding exponent + polynomial(mantissa) implementations of log().
Even the highest-precision version of it is not full float precision, let alone double, but you could fit a polynomial with more terms. Or apparently a ratio of two polynomials works well; that's what VCL uses for double.
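The exponent + polynomial(mantissa) structure for float can be sketched in scalar C++ as below. This is not JRF's code, and fastlog2_sketch is a hypothetical name. The three coefficients are placeholder values for a degree-2 approximation of log2(m)/(m-1) on [1.0, 2.0). They give roughly 3e-3 absolute error over that range, so treat them as assumptions to be refit.

Multiplying the polynomial by (m - 1) keeps log2(1.0) == 0.0 exactly. Each std::fma corresponds to one _mm256_fmadd_ps in a vector version.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Assumes x is positive, normal and finite.
float fastlog2_sketch(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    // Unbiased exponent: bits 23..30 minus the float bias of 127.
    float e = static_cast<float>(static_cast<int>((bits >> 23) & 0xFF) - 127);
    // Mantissa in [1, 2): keep the 23 mantissa bits, OR in the exponent of 1.0f.
    std::uint32_t mbits = (bits & 0x007FFFFFu) | 0x3F800000u;
    float m;
    std::memcpy(&m, &mbits, sizeof m);
    // ASSUMED placeholder coefficients: p(m) ~= log2(m) / (m - 1), Horner form.
    float p = std::fma(0.20444601f, m, -1.0491306f);
    p = std::fma(p, m, 2.2833028f);
    // log2(x) ~= e + p(m) * (m - 1); exact 0 for x == 1.0f.
    return std::fma(p, m - 1.0f, e);
}
```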
I got excellent results from porting JRF's SSE2 function to AVX2 + FMA (and especially AVX512 with _mm512_getexp_ps and _mm512_getmant_ps), once I tuned it carefully. (It was part of a commercial project, so I don't think I can post the code.) A fast approximate implementation for float was exactly what I wanted.
In my use-case, each jrf_fastlog() was independent, so OOO execution nicely hid the FMA latency, and it wasn't even worth using the higher-ILP, shorter-latency polynomial evaluation method that VCL's polynomial_5() function uses ("Estrin's scheme", which does some non-FMA multiplies before the FMAs, resulting in more total instructions).
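The difference between the two evaluation orders can be sketched in scalar C++ for a degree-5 polynomial. This is a generic Estrin's scheme written for illustration, not a copy of VCL's polynomial_5(), and horner5/estrin5 are hypothetical names.

Horner's dependency chain is 5 FMAs long. Estrin's is 3 operations deep (two multiplies to get x^4, then one FMA), at the cost of 2 extra multiplies.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// c[i] is the coefficient of x^i.
// Horner: 5 FMAs, each depending on the previous one.
double horner5(const std::array<double, 6>& c, double x) {
    double r = c[5];
    for (int i = 4; i >= 0; --i) r = std::fma(r, x, c[i]);
    return r;
}

// Estrin: independent low-order pieces first, then combine with x^2 and x^4.
double estrin5(const std::array<double, 6>& c, double x) {
    double x2 = x * x;                     // extra multiplies Horner doesn't need
    double x4 = x2 * x2;
    double p01 = std::fma(c[1], x, c[0]);  // these three FMAs are independent
    double p23 = std::fma(c[3], x, c[2]);
    double p45 = std::fma(c[5], x, c[4]);
    double p0123 = std::fma(p23, x2, p01);
    return std::fma(p45, x4, p0123);
}
```

In an AVX2 version each std::fma becomes _mm256_fmadd_pd and each multiply becomes _mm256_mul_pd. Whether the shorter chain helps depends on whether out-of-order execution can already overlap independent calls, as described above.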
Agner Fog's VCL is now Apache-licensed, so any project can just include it directly. If you want high accuracy, you should just use VCL directly. It's header-only, just inline functions, so it won't bloat your binary.
VCL's log float and double functions are in vectormath_exp.h. There are two main parts to the algorithm:
- Extract the exponent bits and convert that integer back into a float (after adjusting for the bias that IEEE FP uses).
- Extract the mantissa and OR in some exponent bits to get a vector of double values in the [0.5, 1.0) range. Further adjust this with if(mantissa <= SQRT2*0.5) { mantissa += mantissa; } else { exponent++; } (where exponent is the usual unbiased exponent for a [1.0, 2.0) mantissa), and then mantissa -= 1.0.
  Use a polynomial approximation to log(x) that is accurate around x=1.0. For double, VCL's log_d() uses a ratio of two 5th-order polynomials; @harold says this is often good for precision. One division mixed in with a lot of FMAs doesn't usually hurt throughput, but it does have higher latency than an FMA. Using vrcpps + a Newton-Raphson iteration is typically slower than just using vdivps on modern hardware. Using a ratio also creates more ILP by evaluating two lower-order polynomials in parallel, instead of one high-order polynomial, and may lower overall latency vs. one long dep chain for a high-order polynomial (which would also accumulate significant rounding error along that one long chain).
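The range reduction in the second step can be sketched in scalar C++ like this. The names are hypothetical, and the sketch assumes a positive, normal, finite input. std::log1p stands in for VCL's rational approximation, whose coefficients are not reproduced here.

The point is the bookkeeping. With exponent taken for a [1.0, 2.0) mantissa, a [0.5, 1.0) mantissa either gets doubled or the exponent gets incremented. Either way, x == 2^fe * (1 + u) with u roughly in [-0.29, 0.41].

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical result type: x == 2^fe * (1 + u).
struct Reduced {
    double fe;
    double u;
};

// Assumes x is positive, normal and finite.
Reduced reduce_for_log(double x) {
    std::uint64_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    int e = static_cast<int>((bits >> 52) & 0x7FF) - 1023;  // x == 2^e * [1, 2)
    // OR in the exponent field of 0.5 to get m in [0.5, 1): x == 2^(e+1) * m.
    std::uint64_t mbits = (bits & ((1ULL << 52) - 1)) | (1022ULL << 52);
    double m;
    std::memcpy(&m, &mbits, sizeof m);
    const double sqrt2_half = 0.70710678118654752440;
    if (m <= sqrt2_half) {
        m += m;  // m in [1.0, 1.414]; x == 2^e * m
    } else {
        e += 1;  // m in (0.707, 1.0); x == 2^e * m with the incremented e
    }
    return {static_cast<double>(e), m - 1.0};  // m - 1.0 is exact here
}

// Natural log from the reduced pieces; std::log1p stands in for the approximation.
double ln_sketch(double x) {
    Reduced r = reduce_for_log(x);
    return r.fe * 0.69314718055994530942 + std::log1p(r.u);
}
```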
Then add exponent + polynomial_approx_log(mantissa) to get the final log() result. VCL does this in multiple steps to reduce rounding error. ln2_lo + ln2_hi = ln(2); it's split up into a small and a large constant to reduce rounding error.
// res is the polynomial(adjusted_mantissa) result
// fe is the float exponent
// x is the adjusted_mantissa. x2 = x*x;
res = mul_add(fe, ln2_lo, res); // res += fe * ln2_lo;
res += nmul_add(x2, 0.5, x); // res += x - 0.5 * x2;
res = mul_add(fe, ln2_hi, res); // res += fe * ln2_hi;
You can drop the 2-step ln2 stuff and just use VM_LN2 if you aren't aiming for 0.5 or 1 ulp accuracy (or whatever this function actually provides; IDK.)
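The split helps because ln2_hi has only a few significant bits. That makes fe * ln2_hi exact for any double exponent, so the only rounding in the large term happens in the final add.

The sketch below uses the Cephes-style split, which is an assumption here; check vectormath_exp.h for the constants VCL actually uses. It follows the order of the fragment above, with std::fma playing the role of mul_add/nmul_add. combine_ln is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>

// ASSUMED Cephes-style split of ln(2).
constexpr double kLn2Hi = 0.693359375;                  // 355/512: only 9 significant bits
constexpr double kLn2Lo = -2.121944400546905827679e-4;  // ln(2) - kLn2Hi

// fe: exponent as a double; x: reduced mantissa (u);
// poly: the part of ln(1+x) beyond x - 0.5*x^2.
double combine_ln(double fe, double x, double poly) {
    double x2 = x * x;
    double res = std::fma(fe, kLn2Lo, poly);  // res += fe * ln2_lo (small terms first)
    res += std::fma(-0.5, x2, x);             // res += x - 0.5 * x2
    res = std::fma(fe, kLn2Hi, res);          // res += fe * ln2_hi (large term last)
    return res;
}
```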
The x - 0.5*x2 part is really an extra polynomial term, I guess. This is what I meant by log base e being baked-in: you'd need a coefficient on those terms, or to get rid of that line and re-fit the polynomial coefficients for log2. You can't just multiply all the polynomial coefficients by a constant.
After that, it checks for underflow, overflow or denormal, and branches if any element in the vector needs special processing to produce a proper NaN or -Inf rather than whatever garbage we got from the polynomial + exponent. If your values are known to be finite and positive, you can comment out this part and get a significant speedup (even the checking before the branch takes several instructions).
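Here is a scalar sketch of what that slow path has to produce. The names are hypothetical, and std::log2 stands in for the polynomial in log2_core. In a vector version, the checks would typically be done with _mm256_cmp_pd and _mm256_movemask_pd, so the common all-normal case takes one predictable branch after the compares. Subnormal inputs are handled by scaling by 2^52 and subtracting 52 from the result.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical fast core: only valid for positive, normal, finite x.
double log2_core(double x) {
    std::uint64_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    int e = static_cast<int>((bits >> 52) & 0x7FF) - 1023;
    std::uint64_t mbits = (bits & ((1ULL << 52) - 1)) | (1023ULL << 52);
    double m;
    std::memcpy(&m, &mbits, sizeof m);
    return e + std::log2(m);  // std::log2 stands in for the polynomial
}

// Special-case handling around the fast core.
double log2_with_special_cases(double x) {
    if (std::isnan(x) || x < 0.0) return std::numeric_limits<double>::quiet_NaN();
    if (x == 0.0) return -std::numeric_limits<double>::infinity();  // also for -0.0
    if (std::isinf(x)) return x;                                   // +inf stays +inf
    if (x < std::numeric_limits<double>::min()) {                  // subnormal input
        return log2_core(x * 4503599627370496.0) - 52.0;           // scale by 2^52
    }
    return log2_core(x);
}
```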
Further reading:
- http://gallium.inria.fr/blog/fast-vectorizable-math-approx/ some stuff about how to evaluate relative and absolute error in a polynomial approximation, and doing a minimax fit of the coefficients instead of just using a Taylor series expansion.
- http://www.machinedlearnings.com/2011/06/fast-approximate-logarithm-exponential.html an interesting approach: it type-puns a float to uint32_t, and converts that integer to float. Since IEEE binary32 floats store the exponent in higher bits than the mantissa, the resulting float mostly represents the value of the exponent, scaled by 1 << 23, but also containing information from the mantissa. Then it uses an expression with a couple of coefficients to fix things up and get a log() approximation. It includes a division by (constant + mantissa) to correct for the mantissa pollution when converting the float bit-pattern to float. I found that a vectorized version of that was slower and less accurate with AVX2 on HSW and SKL than JRF fastlog with 4th-order polynomials. (Especially when using it as part of a fast arcsinh, which also uses the divide unit for vsqrtps.)
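The core of the type-punning trick is tiny. The sketch below shows only the uncorrected part: reinterpret the bits as an integer, convert to float, scale by 2^-23, and subtract the bias 127. That gives a piecewise-linear approximation with a worst-case error of about 0.086.

The division-based correction term from the linked post is deliberately not reproduced, since its constants aren't given here. bits_as_log2 is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Assumes x is positive, normal and finite. No correction term.
float bits_as_log2(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    // bits / 2^23 == biased exponent + (mantissa - 1), so subtract the bias 127.
    return static_cast<float>(bits) * (1.0f / 8388608.0f) - 127.0f;
}
```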
Finally, here is my best result, which on a Ryzen 1800X @ 3.6GHz gives about 0.8 billion logarithms per second (200 million vectors of 4 logarithms each) in a single thread, and is accurate to within the last few bits of the mantissa. Spoiler: see the end for how to increase performance to 0.87 billion logarithms per second.
Special cases:
Negative numbers, negative infinity and NaNs with a negative sign bit are treated as if they are very close to 0 (resulting in large negative garbage "logarithm" values). Positive infinity and NaNs with a positive sign bit result in a logarithm around 1024. If you don't like how special cases are treated, one option is to add code that checks for them and does what suits you better. This will make the computation slower.
#include <immintrin.h> // AVX2 + FMA intrinsics (compile with e.g. -mavx2 -mfma)
#include <cmath>
#include <cstdint>
#include <cstring>

namespace {
// cnLog2TblBits can be at most 19 because we process only the high 32 bits of doubles,
// and out of the 20 bits of mantissa there, 1 bit is used for rounding.
constexpr uint8_t cnLog2TblBits = 10; // 1024 numbers times 8 bytes = 8KB.
constexpr uint16_t cZeroExp = 1023;
const __m256i gDoubleNotExp = _mm256_set1_epi64x(~(0x7ffULL << 52));
const __m256d gDoubleExp0 = _mm256_castsi256_pd(_mm256_set1_epi64x(1023ULL << 52));
const __m256i cAvxExp2YMask = _mm256_set1_epi64x(
~((1ULL << (52-cnLog2TblBits)) - 1) );
const __m256d cPlusBit = _mm256_castsi256_pd(_mm256_set1_epi64x(
1ULL << (52 - cnLog2TblBits - 1)));
const __m256d gCommMul1 = _mm256_set1_pd(2.0 / 0.693147180559945309417); // 2.0/ln(2)
const __m256i gHigh32Permute = _mm256_set_epi32(0, 0, 0, 0, 7, 5, 3, 1);
const __m128i cSseMantTblMask = _mm_set1_epi32((1 << cnLog2TblBits) - 1);
const __m128i gExpNorm0 = _mm_set1_epi32(1023);
// plus |cnLog2TblBits|th highest mantissa bit
double gPlusLog2Table[1 << cnLog2TblBits];
} // anonymous namespace
// Must be called once before the first call to Log2TblPlus().
void InitLog2Table() {
for(uint32_t i=0; i<(1<<cnLog2TblBits); i++) {
const uint64_t iZp = (uint64_t(cZeroExp) << 52)
| (uint64_t(i) << (52 - cnLog2TblBits)) | (1ULL << (52 - cnLog2TblBits - 1));
// memcpy instead of reinterpret_cast avoids strict-aliasing undefined behavior.
double zp;
std::memcpy(&zp, &iZp, sizeof(zp));
const double l2zp = std::log2(zp);
gPlusLog2Table[i] = l2zp;
}
}
__m256d __vectorcall Log2TblPlus(__m256d x) {
const __m256d zClearExp = _mm256_and_pd(_mm256_castsi256_pd(gDoubleNotExp), x);
const __m256d z = _mm256_or_pd(zClearExp, gDoubleExp0);
const __m128i high32 = _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(
_mm256_castpd_si256(x), gHigh32Permute));
// This requires that x is non-negative, because the sign bit is not cleared before
// computing the exponent.
const __m128i exps32 = _mm_srai_epi32(high32, 20);
const __m128i normExps = _mm_sub_epi32(exps32, gExpNorm0);
// Compute y as approximately equal to log2(z)
const __m128i indexes = _mm_and_si128(cSseMantTblMask,
_mm_srai_epi32(high32, 20 - cnLog2TblBits));
const __m256d y = _mm256_i32gather_pd(gPlusLog2Table, indexes,
/*number of bytes per item*/ 8);
// Compute A as z/exp2(y)
const __m256d exp2_Y = _mm256_or_pd(
cPlusBit, _mm256_and_pd(z, _mm256_castsi256_pd(cAvxExp2YMask)));
// Calculate t=(A-1)/(A+1). Both numerator and denominator would be divided by exp2_Y
const __m256d tNum = _mm256_sub_pd(z, exp2_Y);
const __m256d tDen = _mm256_add_pd(z, exp2_Y);
// Compute the first polynomial term from "More efficient series" of https://en.wikipedia.org/wiki/Logarithm#Power_series
const __m256d t = _mm256_div_pd(tNum, tDen);
const __m256d log2_z = _mm256_fmadd_pd(t, gCommMul1, y);
// Leading integer part for the logarithm
const __m256d leading = _mm256_cvtepi32_pd(normExps);
const __m256d log2_x = _mm256_add_pd(log2_z, leading);
return log2_x;
}
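A scalar model of the same steps can help when checking the vector code lane by lane. This is a sketch for illustration, not part of the original code. kBits mirrors cnLog2TblBits, and the function names are hypothetical.

The model builds the same bucket-center table and computes t = (z - c)/(z + c). It keeps only the first term 2t/ln(2) of the atanh series, so with 10 table bits the truncation error should stay around 1e-11.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

constexpr int kBits = 10;  // plays the role of cnLog2TblBits

// log2 of each bucket center: exponent of 1.0, top kBits mantissa bits = i, plus half a step.
std::vector<double> make_table() {
    std::vector<double> tbl(1u << kBits);
    for (std::uint64_t i = 0; i < tbl.size(); ++i) {
        std::uint64_t b = (1023ULL << 52) | (i << (52 - kBits)) | (1ULL << (52 - kBits - 1));
        double c;
        std::memcpy(&c, &b, sizeof c);
        tbl[i] = std::log2(c);
    }
    return tbl;
}

const std::vector<double>& log2_table() {
    static const std::vector<double> tbl = make_table();
    return tbl;
}

// Assumes x is positive, normal and finite.
double log2_table_model(double x) {
    std::uint64_t bits;
    std::memcpy(&bits, &x, sizeof bits);
    int e = static_cast<int>(bits >> 52) - 1023;                        // leading integer part
    std::uint64_t zb = (bits & ((1ULL << 52) - 1)) | (1023ULL << 52);   // z in [1, 2)
    std::uint64_t idx = (bits >> (52 - kBits)) & ((1u << kBits) - 1);  // top mantissa bits
    std::uint64_t cb = (zb & ~((1ULL << (52 - kBits)) - 1)) | (1ULL << (52 - kBits - 1));
    double z, c;
    std::memcpy(&z, &zb, sizeof z);
    std::memcpy(&c, &cb, sizeof c);  // c is the bucket center, i.e. exp2(table value)
    double t = (z - c) / (z + c);    // ln(z/c) = 2*atanh(t) ~= 2*t
    double log2_z = std::fma(t, 2.0 / 0.69314718055994530942, log2_table()[idx]);
    return log2_z + e;
}
```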
It uses a combination of a lookup-table approach and a 1st-degree polynomial, mostly described on Wikipedia (the link is in the code comments). I can afford to allocate 8KB of L1 cache here (which is half of the 16KB of L1 cache available per logical core), because logarithm computation is really the bottleneck for me and not much else needs L1 cache.
However, if you need more L1 cache for other purposes, you can decrease the amount of cache used by the logarithm algorithm by reducing cnLog2TblBits to e.g. 5, at the expense of accuracy.
Or, to keep the accuracy high with a smaller table, you can increase the number of polynomial terms by adding:
namespace {
// ...
const __m256d gCoeff1 = _mm256_set1_pd(1.0 / 3);
const __m256d gCoeff2 = _mm256_set1_pd(1.0 / 5);
const __m256d gCoeff3 = _mm256_set1_pd(1.0 / 7);
const __m256d gCoeff4 = _mm256_set1_pd(1.0 / 9);
const __m256d gCoeff5 = _mm256_set1_pd(1.0 / 11);
}
And then changing the tail of Log2TblPlus() after the line const __m256d t = _mm256_div_pd(tNum, tDen); to:
const __m256d t2 = _mm256_mul_pd(t, t); // t**2
const __m256d t3 = _mm256_mul_pd(t, t2); // t**3
const __m256d terms01 = _mm256_fmadd_pd(gCoeff1, t3, t);
const __m256d t5 = _mm256_mul_pd(t3, t2); // t**5
const __m256d terms012 = _mm256_fmadd_pd(gCoeff2, t5, terms01);
const __m256d t7 = _mm256_mul_pd(t5, t2); // t**7
const __m256d terms0123 = _mm256_fmadd_pd(gCoeff3, t7, terms012);
const __m256d t9 = _mm256_mul_pd(t7, t2); // t**9
const __m256d terms01234 = _mm256_fmadd_pd(gCoeff4, t9, terms0123);
const __m256d t11 = _mm256_mul_pd(t9, t2); // t**11
const __m256d terms012345 = _mm256_fmadd_pd(gCoeff5, t11, terms01234);
const __m256d log2_z = _mm256_fmadd_pd(terms012345, gCommMul1, y);
The comment // Leading integer part for the logarithm and the rest of the function then follow unchanged.
Normally you don't need that many terms, even for a few-bit table; I just provided the coefficients and computations for reference. It's likely that if cnLog2TblBits==5, you won't need anything beyond terms012. But I haven't done such measurements, so you need to experiment to find what suits your needs.
The fewer polynomial terms you compute, obviously, the faster the computation is.
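The truncation error can be estimated on paper to get a rough idea before measuring. Assuming z in [1, 2) and table entries at bucket centers, |t| <= 2^-(cnLog2TblBits+2), and the first omitted series term dominates.

The helper below has a hypothetical name and ignores rounding error, including the table's own rounding. It suggests that 10 bits with one term gives about 1.4e-11, while 5 bits with terms012 (three terms) gives under 1e-15. That is consistent with the guess above.

```cpp
#include <cassert>
#include <cmath>

// Rough truncation-error estimate (in log2 units) for the table + atanh-series scheme.
// nTerms = number of series terms kept: t, t^3/3, ..., t^(2n-1)/(2n-1).
double truncation_bound(int tableBits, int nTerms) {
    const double tMax = std::ldexp(1.0, -(tableBits + 2));  // largest |t|
    const int k = 2 * nTerms + 1;                           // first omitted power
    // First omitted term, inflated by 1/(1 - tMax^2) to cover the rest of the tail.
    return (2.0 / 0.69314718055994530942) * std::pow(tMax, k) / k / (1.0 - tMax * tMax);
}
```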
EDIT: the question "In what situation would the AVX2 gather instructions be faster than individually loading the data?" suggests that you may get a performance improvement if
const __m256d y = _mm256_i32gather_pd(gPlusLog2Table, indexes,
/*number of bytes per item*/ 8);
is replaced by
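// indexes.m128i_u32[k] is MSVC-only; with GCC/Clang, _mm_extract_epi32(indexes, k) (SSE4.1) does the same job.
const __m256d y = _mm256_set_pd(gPlusLog2Table[indexes.m128i_u32[3]],
gPlusLog2Table[indexes.m128i_u32[2]],
gPlusLog2Table[indexes.m128i_u32[1]],
gPlusLog2Table[indexes.m128i_u32[0]]);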
For my implementation it saves about 1.5 cycles, reducing the total cycle count to compute 4 logarithms from 18 to 16.5, so the performance rises to 0.87 billion logarithms per second. I'm leaving the current implementation as is because it's more idiomatic and should be faster once CPUs start doing gather operations right (with coalescing like GPUs do).
EDIT2: on Ryzen CPUs (but not on Intel) you can get a little more speedup (about 0.5 cycle) by replacing
const __m128i high32 = _mm256_castsi256_si128(_mm256_permutevar8x32_epi32(
_mm256_castpd_si256(x), gHigh32Permute));
with
const __m128 hiLane = _mm_castpd_ps(_mm256_extractf128_pd(x, 1));
const __m128 loLane = _mm_castpd_ps(_mm256_castpd256_pd128(x));
const __m128i high32 = _mm_castps_si128(_mm_shuffle_ps(loLane, hiLane,
_MM_SHUFFLE(3, 1, 3, 1)));