How should I understand the use of the #define directive in this code example?

I have been looking into the #define directive recently, and I am confused by how #define is used in the following code example. Could anyone explain how it works?

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float* val_list) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b)                               \
  val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), a, b); \
  val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), a, b); \
  *(val_list + 0) += val0_tmp;                                   \
  *(val_list + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);

#undef WarpReduceSumOneStep
}

From my understanding, when WarpReduceSumOneStep(16, 32); occurs, the compiler substitutes it with the block between #define and #undef, right?


Every preprocessor directive takes up exactly one logical line. The \ at the end of some of the lines says "pretend that this isn't really the end of a line": the preprocessor joins such a line with the one that follows it. So the definition of the macro WarpReduceSumOneStep includes the next four source lines. (They're highlighted in blue on my system.) The macro definition ends at the end of that last line, because that line doesn't end with a \; there, the end of the line really is the end of the line.
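To see the line-continuation rule in isolation, here is a minimal host-side C++ sketch. ADD_TWICE and add_twice_demo are hypothetical names made up for illustration; they are not part of the code in the question.

```cpp
#include <cassert>

// Hypothetical macro (not from the question): two statements spread over
// three source lines. The trailing backslashes join them into one logical
// line, so all of it belongs to the #define.
#define ADD_TWICE(acc, x) \
  (acc) += (x);           \
  (acc) += (x)

int add_twice_demo() {
  int total = 0;
  // Expands to: (total) += (5); (total) += (5);
  // The final ';' comes from this line, not from the macro body.
  ADD_TWICE(total, 5);
  assert(total == 10);
  return total;
}
```

The last line of the macro body has no backslash, so the definition stops there, just as in the question's macro.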

After the end of the macro definition, any use of the macro WarpReduceSumOneStep gets replaced by the text in the macro's definition. In the code in the question, that macro is used five times:

WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);

Each of those is replaced by the text of the macro, with arguments replaced appropriately. I'm not going to go through all five; they're pretty much the same. The first one, WarpReduceSumOneStep(16, 32), becomes

val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), 16, 32);
val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), 16, 32);
*(val_list + 0) += val0_tmp;
*(val_list + 1) += val1_tmp

and because there's a semicolon after the invocation of the macro, the full text becomes

val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), 16, 32);
val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), 16, 32);
*(val_list + 0) += val0_tmp;
*(val_list + 1) += val1_tmp;
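Note that the expanded text refers to val0_tmp, val1_tmp and val_list, none of which are macro parameters. That works only because the replacement is purely textual and lands inside a function where those names are declared. The sketch below shows this on the CPU. It assumes a hypothetical stand-in, fake_shfl_xor, that simply returns its value argument (the real __shfl_xor_sync exchanges values between warp lanes and behaves differently), and FAKE_MASK is an assumed placeholder for FINAL_MASK.

```cpp
#include <cassert>

// Hypothetical host-side stand-in for __shfl_xor_sync. It just returns the
// value it is given, which is NOT what the CUDA intrinsic does; it exists
// only so that the macro can expand and run without a GPU.
static float fake_shfl_xor(unsigned /*mask*/, float value,
                           int /*lane_mask*/, int /*width*/) {
  return value;
}

#define FAKE_MASK 0xffffffffu  // assumed placeholder for FINAL_MASK

float one_step_demo(float* val_list) {
  float val0_tmp, val1_tmp;  // the macro body uses these names directly

#define OneStep(a, b)                                         \
  val0_tmp = fake_shfl_xor(FAKE_MASK, *(val_list + 0), a, b); \
  val1_tmp = fake_shfl_xor(FAKE_MASK, *(val_list + 1), a, b); \
  *(val_list + 0) += val0_tmp;                                \
  *(val_list + 1) += val1_tmp

  OneStep(16, 32);  // this ';' terminates the macro's last statement

#undef OneStep

  return *(val_list + 0) + *(val_list + 1);
}
```

With this stand-in each element is simply added to itself, so an input of {1.0f, 2.0f} becomes {2.0f, 4.0f}. If the function did not declare val0_tmp and val1_tmp, the expanded code would fail to compile.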

After those five uses of the macro, the #undef removes its definition. After that, any use of the name WarpReduceSumOneStep is just a use of that name; the preprocessor won't do anything special with it.
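The effect of #undef can be checked with #ifdef. In this small sketch all names (Square, with_macro, after_undef and the two flags) are hypothetical and chosen only for illustration.

```cpp
#include <cassert>

#define Square(x) ((x) * (x))

// Here Square is a macro: the call below is replaced by ((3) * (3)).
int with_macro() { return Square(3); }

#ifdef Square
const bool defined_before_undef = true;
#else
const bool defined_before_undef = false;
#endif

#undef Square

#ifdef Square
const bool defined_after_undef = true;
#else
const bool defined_after_undef = false;
#endif

// After #undef the preprocessor ignores the name, so it can be reused
// for an ordinary function.
int Square(int x) { return x * x; }

int after_undef() { return Square(3); }  // a normal function call
```

Before the #undef the name is a macro; after it, the preprocessor leaves the name alone and the compiler treats it like any other identifier.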


Not exactly. #define creates a macro (a 'keyword reference', so to speak) named WarpReduceSumOneStep, and its definition is only the first part, the lines joined by the trailing backslashes, that is

WarpReduceSumOneStep(a, b)                               \
  val0_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 0), a, b); \
  val1_tmp = __shfl_xor_sync(FINAL_MASK, *(val_list + 1), a, b); \
  *(val_list + 0) += val0_tmp;                                   \
  *(val_list + 1) += val1_tmp

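That section of code then uses the macro five times, with first arguments 16, 8, 4, 2, and 1. Finally, the #undef removes the macro definition (that 'keyword reference'), so nothing after that point in the same translation unit, meaning the rest of that file and any headers it includes later, can use it. Macros are never shared between separately compiled source files in the first place.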