Spark: how to get the number of written rows?
Solution 1:
If you really want to, you can add a custom listener and extract the number of written rows from outputMetrics. A very simple example can look like this:
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

sc.addSparkListener(new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    synchronized {
      recordsWrittenCount += taskEnd.taskMetrics.outputMetrics.recordsWritten
    }
  }
})
sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar")
recordsWrittenCount
// Long = 10
Note, however, that this part of the API is intended for internal usage.
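If you need per-write counts, a minimal sketch reusing the same listener idea is to keep one listener registered and reset its counter before each action (WriteCounter is a name introduced here for illustration, not a Spark API). Keep in mind that listener events are delivered asynchronously, so read the counter only after the job has finished.

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical helper: accumulates recordsWritten across tasks
// and can be reset between actions.
object WriteCounter extends SparkListener {
  @volatile var recordsWritten = 0L
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
    recordsWritten += taskEnd.taskMetrics.outputMetrics.recordsWritten
  }
  def reset(): Unit = synchronized { recordsWritten = 0L }
}

sc.addSparkListener(WriteCounter)

WriteCounter.reset()
sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar2")
WriteCounter.recordsWritten  // 10, once the asynchronous listener events have been processed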
Solution 2:
The accepted answer more closely matches the OP's specific needs (as made explicit in various comments); nevertheless, this answer will suit the majority.
The most efficient approach is to use an Accumulator: http://spark.apache.org/docs/latest/programming-guide.html#accumulators
val accum = sc.accumulator(0L)

data.map { x =>
  accum += 1
  x
}.saveAsTextFile(path)

val count = accum.value
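On Spark 2.x and later, sc.accumulator is deprecated; a rough equivalent with the newer API (same data and path as above) would be:

val accum = sc.longAccumulator("records written")

data.map { x =>
  accum.add(1)
  x
}.saveAsTextFile(path)

val count = accum.value

Either way, remember that accumulator updates performed inside transformations may be applied more than once if a task is re-executed, so the count can over-report in the presence of task retries or speculative execution.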
You can then wrap this in a useful pimp:
import org.apache.spark.rdd.RDD

implicit class PimpedStringRDD(rdd: RDD[String]) {
  def saveAsTextFileAndCount(p: String): Long = {
    val accum = rdd.sparkContext.accumulator(0L)
    rdd.map { x =>
      accum += 1
      x
    }.saveAsTextFile(p)
    accum.value
  }
}
So you can do
val count = data.saveAsTextFileAndCount(path)
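If you want the same convenience for arbitrary element types, a generalised sketch could look like this (RichRDD is a name introduced here, reusing the longAccumulator API mentioned above):

import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Works for any RDD[T]; saveAsTextFile will call toString on the elements.
implicit class RichRDD[T: ClassTag](rdd: RDD[T]) {
  def saveAsTextFileAndCount(p: String): Long = {
    val accum = rdd.sparkContext.longAccumulator("records written")
    rdd.map { x => accum.add(1); x }.saveAsTextFile(p)
    accum.value
  }
}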
Solution 3:
If you look at taskEnd.taskInfo.accumulables, you will see that it is bundled with the following AccumulableInfo entries, held in a ListBuffer in sequential order:
AccumulableInfo(1,Some(internal.metrics.executorDeserializeTime),Some(33),Some(33),true,true,None),
AccumulableInfo(2,Some(internal.metrics.executorDeserializeCpuTime),Some(32067956),Some(32067956),true,true,None),
AccumulableInfo(3,Some(internal.metrics.executorRunTime),Some(325),Some(325),true,true,None),
AccumulableInfo(4,Some(internal.metrics.executorCpuTime),Some(320581946),Some(320581946),true,true,None),
AccumulableInfo(5,Some(internal.metrics.resultSize),Some(1459),Some(1459),true,true,None),
AccumulableInfo(7,Some(internal.metrics.resultSerializationTime),Some(1),Some(1),true,true,None),
AccumulableInfo(0,Some(number of output rows),Some(3),Some(3),true,true,Some(sql))
You can clearly see that the number of output rows is in the 7th position of the ListBuffer, so the correct way to get the count of rows being written is
taskEnd.taskInfo.accumulables(6).value.get
We can get the rows written in the following way (I just modified @zero323's answer):
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

var recordsWrittenCount = 0L

sc.addSparkListener(new SparkListener() {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    synchronized {
      recordsWrittenCount += taskEnd.taskInfo.accumulables(6).value.get.asInstanceOf[Long]
    }
  }
})
sc.parallelize(1 to 10, 2).saveAsTextFile("/tmp/foobar")
recordsWrittenCount
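Relying on a fixed index is fragile, since the order and number of entries in accumulables can differ between Spark versions and queries. A more defensive sketch (assuming the SQL metric is reported under the name "number of output rows") looks the entry up by name instead:

import org.apache.spark.scheduler.SparkListenerTaskEnd

// Find the "number of output rows" accumulable by name rather than by position.
def outputRows(taskEnd: SparkListenerTaskEnd): Long =
  taskEnd.taskInfo.accumulables
    .find(_.name.contains("number of output rows"))
    .flatMap(_.value)
    .map(_.toString.toLong)
    .getOrElse(0L)

Inside onTaskEnd you would then do recordsWrittenCount += outputRows(taskEnd) instead of indexing into the buffer.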