How to calculate the max value across some columns per row in PySpark
There is a function for that: pyspark.sql.functions.greatest.
>>> from pyspark.sql.functions import greatest
>>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c'])
>>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect()
[Row(greatest=4)]
The example was taken directly from the docs. (pyspark.sql.functions.least does the opposite and returns the row-wise minimum.)
I think combining the values into a list and then finding the max of that list is the simplest approach.
from pyspark.sql.types import ArrayType, IntegerType, StructField, StructType
import pyspark.sql.functions as F

schema = StructType([
    StructField("ClientId", IntegerType(), True),
    StructField("m_ant21", IntegerType(), True),
    StructField("m_ant22", IntegerType(), True),
    StructField("m_ant23", IntegerType(), True),
    StructField("m_ant24", IntegerType(), True)
])

# Assumes an existing SparkSession named `spark`
df = spark\
    .createDataFrame(
        data=[(0, None, None, None, None),
              (1, 23, 13, 17, 99),
              (2, 0, 0, 0, 1),
              (3, 0, None, 1, 0)],
        schema=schema)

# Python UDF that packs the four values of a row into one array column
def agg_to_list(m21, m22, m23, m24):
    return [m21, m22, m23, m24]

u_agg_to_list = F.udf(agg_to_list, ArrayType(IntegerType()))

# sort_array(..., asc=False) sorts descending and places nulls last,
# so element [0] is the row maximum (null only if every value is null)
df2 = df.withColumn('all_values', u_agg_to_list('m_ant21', 'm_ant22', 'm_ant23', 'm_ant24'))\
    .withColumn('max', F.sort_array("all_values", False)[0])\
    .select('ClientId', 'max')

df2.show()
Output:
+--------+----+
|ClientId| max|
+--------+----+
|       0|null|
|       1|  99|
|       2|   1|
|       3|   1|
+--------+----+