Spark: how to get the number of written rows?

Solution 1:

If you really want to, you can add a custom listener and extract the number of written rows from outputMetrics. A very simple example can look like this:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

sc.addSparkListener(new SparkListener() {
  // Sum the records written by every task that completes
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    synchronized {
      recordsWrittenCount += taskEnd.taskMetrics.outputMetrics.recordsWritten
    }
  }
})

sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar")
recordsWrittenCount
// Long = 10

Keep in mind, however, that this part of the API is intended for internal usage.
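
If you do go down this route, it is also worth keeping a reference to the listener so it can be unregistered once you are done, otherwise it keeps adding counts from every later job in the same context. A minimal sketch, assuming a Spark version whose SparkContext exposes removeSparkListener (the writeCounter name is just illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

// Keep a handle on the listener so it can be removed later
val writeCounter = new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    recordsWrittenCount += taskEnd.taskMetrics.outputMetrics.recordsWritten
  }
}

sc.addSparkListener(writeCounter)
sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar")
// Listener events are delivered asynchronously, so in a real job you may want
// to give the listener bus a moment to drain before reading the count.
recordsWrittenCount
// Long = 10
sc.removeSparkListener(writeCounter) // stop counting once you are done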

Solution 2:

The accepted answer more closely matches the OP's specific needs (as made explicit in various comments); nevertheless, this answer will suit the majority of use cases.

The most efficient approach is to use an Accumulator: http://spark.apache.org/docs/latest/programming-guide.html#accumulators

val accum = sc.accumulator(0L)

// Count each element as it passes through the map
data.map { x =>
  accum += 1
  x
}
.saveAsTextFile(path)

val count = accum.value
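
As an aside, sc.accumulator has been deprecated since Spark 2.0; on newer versions a roughly equivalent sketch (using the same sc, data and path as above) relies on the AccumulatorV2-based longAccumulator instead:

val accum = sc.longAccumulator("records written")

data.map { x =>
  accum.add(1)  // count each element as it flows through the map
  x
}
.saveAsTextFile(path)

val count = accum.value  // java.lang.Long

Either way, remember that accumulator updates made inside transformations can be applied more than once if tasks are retried, so the count can overshoot on failure-heavy jobs.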

You can then wrap this in a useful pimp (an implicit class):

import org.apache.spark.rdd.RDD

implicit class PimpedStringRDD(rdd: RDD[String]) {
  def saveAsTextFileAndCount(p: String): Long = {
    val accum = rdd.sparkContext.accumulator(0L)

    rdd.map { x =>
      accum += 1
      x
    }
    .saveAsTextFile(p)

    accum.value
  }
}

So you can do:

val count = data.saveAsTextFileAndCount(path)

Solution 3:

If you look at

taskEnd.taskInfo.accumulables

you will see that it is bundled with the following AccumulableInfo entries in a ListBuffer, in sequential order:

AccumulableInfo(1,Some(internal.metrics.executorDeserializeTime),Some(33),Some(33),true,true,None),
AccumulableInfo(2,Some(internal.metrics.executorDeserializeCpuTime),Some(32067956),Some(32067956),true,true,None),
AccumulableInfo(3,Some(internal.metrics.executorRunTime),Some(325),Some(325),true,true,None),
AccumulableInfo(4,Some(internal.metrics.executorCpuTime),Some(320581946),Some(320581946),true,true,None),
AccumulableInfo(5,Some(internal.metrics.resultSize),Some(1459),Some(1459),true,true,None),
AccumulableInfo(7,Some(internal.metrics.resultSerializationTime),Some(1),Some(1),true,true,None),
AccumulableInfo(0,Some(number of output rows),Some(3),Some(3),true,true,Some(sql))

You can clearly see that the number of output rows accumulable is at the 7th position (index 6) of the ListBuffer, so the correct way to get the count of rows being written is

taskEnd.taskInfo.accumulables(6).value.get

We can get the rows written in the following way (I just modified @zero323's answer):

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

sc.addSparkListener(new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    synchronized {
      // Index 6 holds the "number of output rows" accumulable (see listing above)
      recordsWrittenCount += taskEnd.taskInfo.accumulables(6).value.get.asInstanceOf[Long]
    }
  }
})

sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar")
recordsWrittenCount
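
Relying on the accumulable sitting at index 6 is brittle, though, since the exact contents and ordering of taskInfo.accumulables can differ between Spark versions and write paths. A more defensive sketch (assuming the metric is still exposed under the name "number of output rows") looks it up by name instead:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

sc.addSparkListener(new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    // Find the accumulable by name rather than by position in the buffer
    recordsWrittenCount += taskEnd.taskInfo.accumulables
      .find(_.name.contains("number of output rows"))
      .flatMap(_.value)
      .map(_.toString.toLong)
      .getOrElse(0L)
  }
})

If the named accumulable is not present (for example for plain RDD writes that do not go through Spark SQL), the outputMetrics approach from Solution 1 is the safer bet.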