How to understand the usage of #define directive in this code example?
I have been looking into the #define directive recently, and I am confused by how it is used in the following code example. Could anyone explain how it works?
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float* val_list) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), a, b); \
  val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), a, b); \
  *(val_list + 0) += val0_tmp; \
  *(val_list + 1) += val1_tmp
  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
From my understanding, when WarpReduceSumOneStep(16, 32); occurs, the compiler substitutes it with the block between #define and #undef, right?
Every preprocessor directive takes up exactly one line. The \ at the end of some of the lines says "pretend that this isn't really the end of the line". So the definition of the macro WarpReduceSumOneStep includes the next four source lines. They're highlighted in blue on my system. The macro definition ends at the end of that last line, because it doesn't have a \; the end of that line really is the end of the line.
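As a minimal sketch of that line-continuation rule, here is a small host-side C++ example. ADD_THREE and add_three_demo are made-up names for illustration; they are not part of the code in the question.

```cpp
#include <cassert>

// The two backslashes splice the following lines onto the #define line,
// so the preprocessor sees one logical line containing the whole definition.
// The last line has no backslash, so the definition ends there.
#define ADD_THREE(a, b, c) \
    ((a) +                 \
     (b) + (c))

int add_three_demo() {
    // Expands to: int sum = ((1) + (2) + (3));
    int sum = ADD_THREE(1, 2, 3);
    assert(sum == 6);
    return sum;
}
```

If one of the backslashes were removed, the definition would stop at that line and the remaining text would be compiled as ordinary code.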
After the end of the macro definition, any use of the macro WarpReduceSumOneStep gets replaced by the text in the macro's definition. In the code in the question, that macro is used five times:
WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);
Each of those is replaced by the text of the macro, with the parameters a and b replaced by the corresponding arguments. I'm not going to go through all five; they're pretty much the same. The first one, WarpReduceSumOneStep(16, 32), becomes
val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), 16, 32);
val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), 16, 32);
*(val_list + 0) += val0_tmp;
*(val_list + 1) += val1_tmp
and because there's a semicolon after the invocation of the macro, the full text becomes
val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), 16, 32);
val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), 16, 32);
*(val_list + 0) += val0_tmp;
*(val_list + 1) += val1_tmp;
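The same pattern can be tried on the host with a rough sketch like the one below. It assumes a hypothetical stand-in, fake_shfl_xor, that just returns its first argument; it is not the CUDA intrinsic and does no cross-lane exchange. ONE_STEP and reduce_demo are likewise invented names. The point is only to show that the macro body refers to the locals by name and that the call site supplies the final semicolon.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for __shfl_xor_sync: it ignores the lane
// mask and width and simply returns the value it was given.
static float fake_shfl_xor(float value, int /*lane_mask*/, int /*width*/) {
    return value;
}

float reduce_demo(float start) {
    float val = start;
    float tmp;  // the macro body below refers to this local by name

// No semicolon after the last statement: each call site supplies it.
#define ONE_STEP(a, b) \
    tmp = fake_shfl_xor(val, a, b); \
    val += tmp

    ONE_STEP(16, 32);  // becomes: tmp = fake_shfl_xor(val, 16, 32); val += tmp;
    ONE_STEP(8, 32);   // becomes: tmp = fake_shfl_xor(val, 8, 32); val += tmp;
#undef ONE_STEP

    assert(val == start * 4);  // with this stand-in, each step doubles val
    return val;
}
```

Running the compiler with its preprocess-only option (for example -E with GCC or Clang) prints the expanded text, which is a convenient way to check an expansion like this.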
After those five uses of the macro, the #undef removes its definition. After that, any use of the name WarpReduceSumOneStep is just a use of that name; the preprocessor won't do anything special with it.
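A small sketch of that effect, using invented names (STEP, with_macro, without_macro) that do not appear in the question: once the macro is undefined, the same name can even be reused for an ordinary function.

```cpp
#include <cassert>

#define STEP(x) ((x) * 2)

int with_macro() {
    return STEP(21);  // the preprocessor rewrites this to ((21) * 2)
}

#undef STEP

// STEP is no longer a macro, so here it is just an identifier
// and can name a normal function.
int STEP(int x) {
    return x + 1;
}

int without_macro() {
    return STEP(21);  // an ordinary call to the function STEP
}

void undef_demo() {
    assert(with_macro() == 42);
    assert(without_macro() == 22);
}
```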
Not exactly. #define defines the macro WarpReduceSumOneStep, and its definition is only the first part, that is
WarpReduceSumOneStep(a, b) \
val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), a, b); \
val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), a, b); \
*(val_list + 0) += val0_tmp; \
*(val_list + 1) += val1_tmp
That section of code then uses the macro five times, with 16, 8, 4, 2, and 1 as the first argument. The #undef then removes the macro definition, so that the rest of that file, and any headers included after that point, cannot use it. Macros are never shared with separately compiled source files in any case.