How to interact with each element of an ArrayType column in pyspark?
If I have an ArrayType column in pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, []), (2, [1, 2, 3]), (3, [-2])],
    schema=StructType([StructField("a", IntegerType()), StructField("b", ArrayType(IntegerType()))]),
)
df.show()
output:
+---+---------+
|  a|        b|
+---+---------+
|  1|       []|
|  2|[1, 2, 3]|
|  3|     [-2]|
+---+---------+
Now, I want to apply some operation to each element of column b. For example:
- Divide each element by 5. Expected output:
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
- Add a constant to each element, etc.
How do I go about such transformations, where some operator or function is applied to each element of an ArrayType column?
Solution 1:
You are looking for the transform function, which applies a computation to each element of an array.
from pyspark.sql import functions as F
# Spark 2.4+ (before 3.1.0, only the SQL expression form is available)
df.withColumn("b", F.expr("transform(b, x -> x / 5)")).show()
"""
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
"""
# Spark >= 3.1.0: F.transform accepts a Python lambda directly
df.withColumn("b", F.transform("b", lambda x: x / 5)).show()
"""
+---+---------------+
|  a|              b|
+---+---------------+
|  1|             []|
|  2|[0.2, 0.4, 0.6]|
|  3|         [-0.4]|
+---+---------------+
"""